{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np # get rid of this eventually\n",
    "import argparse\n",
    "from jax import jit\n",
    "from jax.experimental.ode import odeint\n",
    "from functools import partial # reduces arguments to function by making some subset implicit\n",
    "\n",
    "from jax.experimental import stax\n",
    "from jax.experimental import optimizers\n",
    "\n",
    "import os, sys, time\n",
    "sys.path.append('..')\n",
    "sys.path.append('../hyperopt')\n",
    "from HyperparameterSearch import extended_mlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import mlp as make_mlp\n",
    "from utils import wrap_coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hamiltonian_eom(hamiltonian, state, conditionals, t=None):\n",
    "    \"\"\"Return d(state)/dt = [q_t, q_tt] under a Hamiltonian H(q, p, conditionals).\n",
    "\n",
    "    Args:\n",
    "        hamiltonian: scalar-valued function H(q, p, conditionals); it is\n",
    "            differentiated with respect to its first two arguments.\n",
    "        state: array split into two equal halves, position q and velocity q_t.\n",
    "        conditionals: extra model inputs; divided by 10 before reaching H.\n",
    "        t: unused; keeps the signature compatible with ODE-solver callbacks.\n",
    "\n",
    "    Returns:\n",
    "        Concatenation of dH/dp (the predicted velocity) and the acceleration.\n",
    "    \"\"\"\n",
    "    position, velocity = jnp.split(state, 2)\n",
    "    scaled_conditionals = conditionals / 10.0  # normalize before feeding H\n",
    "\n",
    "    # Generalized -> canonical coordinates: p = v / (1 - v^2)^(3/2).\n",
    "    momentum = velocity/(1.0-velocity**2)**(3/2.0)\n",
    "\n",
    "    # Hamilton's equation dp/dt = -dH/dq, mapped back to dv/dt using\n",
    "    # dp/dt = dv/dt * (1 + 2 v^2) / (1 - v^2)^(5/2).\n",
    "    momentum_rate = -jax.grad(hamiltonian, 0)(position, momentum, scaled_conditionals)\n",
    "    acceleration = momentum_rate*(1-velocity**2)**(5.0/2)/(1.0+2*velocity**2)\n",
    "\n",
    "    # The returned velocity is the model's dq/dt = dH/dp, not the input q_t.\n",
    "    predicted_velocity = jax.grad(hamiltonian, 1)(position, momentum, scaled_conditionals)\n",
    "    return jnp.concatenate([predicted_velocity, acceleration])\n",
    "\n",
    "# Hamiltonian represented by a parametric model (network params + forward function).\n",
    "def learned_dynamics(params, nn_forward_fn):\n",
    "    \"\"\"Wrap `nn_forward_fn(params, x)` as a jitted function H(q, p, conditionals).\n",
    "\n",
    "    q, p and conditionals are concatenated into one input vector and the\n",
    "    trailing axis of the network output is squeezed (assumes it has size 1,\n",
    "    so H is scalar as jax.grad in hamiltonian_eom requires).\n",
    "    \"\"\"\n",
    "    def dynamics(q, p, conditionals):\n",
    "        network_input = jnp.concatenate([q, p, conditionals])\n",
    "        network_output = nn_forward_fn(params, network_input)\n",
    "        return jnp.squeeze(network_output, axis=-1)\n",
    "\n",
    "    return jit(dynamics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ObjectView:\n",
    "    \"\"\"Attribute-style access to a dict, e.g. ObjectView({'a': 1}).a == 1.\n",
    "\n",
    "    The dict becomes the instance __dict__ itself (shared, not copied).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        self.__dict__ = d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def qdotdot(q, q_t, conditionals):\n",
    "    \"\"\"Ground-truth derivatives (q_t, q_tt) used to build training targets.\n",
    "\n",
    "    q_tt = g * (1 - q_t**2)**(5/2) / (1 + 2 * q_t**2) with g = conditionals.\n",
    "    `q` does not enter the formula; it is accepted to mirror the state layout.\n",
    "    \"\"\"\n",
    "    acceleration = conditionals * (1 - q_t**2)**(5./2) / (1 + 2 * q_t**2)\n",
    "    return q_t, acceleration\n",
    "\n",
    "@jax.jit\n",
    "def ofunc(y, t=None):\n",
    "    \"\"\"odeint right-hand side for a flat state of interleaved (q, q_t, g) triples.\n",
    "\n",
    "    Returns the interleaved derivatives (q_t, q_tt, 0): g is held constant.\n",
    "    `t` is unused but required by the odeint callback signature.\n",
    "    \"\"\"\n",
    "    q, q_t, g = y.reshape(-1, 3).T\n",
    "    velocity, acceleration = qdotdot(q, q_t, g)\n",
    "    return jnp.stack([velocity, acceleration, jnp.zeros_like(g)], axis=1).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(0.99989885, dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Max of 1000 tanh-squashed uniform samples; the 0.99999 factor keeps the result strictly below 1.\n",
    "(jnp.tanh(jax.random.uniform(jax.random.PRNGKey(1), (1000,))*10-5)*0.99999).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([30.,  8.,  5.,  8.,  3.,  2.,  5.,  3.,  5., 31.]),\n",
       " array([-9.9991339e-01, -7.9992497e-01, -5.9993654e-01, -3.9994809e-01,\n",
       "        -1.9995967e-01,  2.8759241e-05,  2.0001718e-01,  4.0000561e-01,\n",
       "         5.9999406e-01,  7.9998249e-01,  9.9997091e-01], dtype=float32),\n",
       " <a list of 10 Patch objects>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPx0lEQVR4nO3dfaxkd13H8feHLhQVtFt6qUth2RYr0sSwJTe1sQkP5amUhJZYdJuAi9YsIBiImLiAiWg0tkZoYiTgYmtXxfJQaLoKiEsfQkigeIul3bIpuy1VS9furaU8xFhp+/WPORfHu3N35t55uPuz71cymTO/8zvnfO9vZj975sw5M6kqJEntecJ6FyBJWhsDXJIaZYBLUqMMcElqlAEuSY3aMMuNnXTSSbVly5ZZblKSmnfLLbc8UFVzy9tnGuBbtmxhYWFhlpuUpOYl+ZdB7R5CkaRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRs30SkxJWk9bdn563bZ9z6Wvnvg63QOXpEYZ4JLUqKEBnuTJSb6S5GtJ7kjye137qUluTnIgyceSPGn65UqSloyyB/4wcG5VPR/YCpyX5GzgMuDyqjod+DZwyfTKlCQtNzTAq+f73cMndrcCzgWu6dp3AxdOpUJJ0kAjnYWS5DjgFuCngA8AdwEPVdUjXZd7gVNWWHYHsANg8+bNay70/9unx5I0rpE+xKyqR6tqK/BM4CzgeYO6rbDsrqqar6r5ubkjflBCkrRGqzoLpaoeAm4CzgZOSLK0B/9M4L7JliZJOppRzkKZS3JCN/0jwMuA/cCNwEVdt+3AddMqUpJ0pFGOgW8CdnfHwZ8AfLyq/j7J14GPJvkD4J+BK6ZYpyRpmaEBXlW3AWcOaL+b3vFwSdI68EpMSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWrU0ABP8qwkNybZn+SOJG/v2t+b5FtJbu1u50+/XEnSkg0j9HkEeGdVfTXJU4Fbkuzt5l1eVX8yvfIkSSsZGuBVdQg41E1/L8l+4JRpFyZJOrpVHQNPsgU4E7i5a3pbktuSXJlk4wrL7EiykGRhcXFxrGIlSf9r5ABP8hTgk8A7quq7wAeB5wBb6e2hv2/QclW1q6rmq2p+bm5uAiVLkmDEAE/yRHrh/ZGq+hRAVd1fVY9W1WPAh4GzplemJGm5Uc5CCXAFsL+q3t/Xvqmv22uBfZMvT5K0klHOQjkHeANwe5Jbu7Z3Axcn2QoUcA/wpqlUKEkaaJSzUL4IZMCsz0y+HEnSqLwSU5IaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1KihAZ7kWUluTLI/yR1J3t61n5hkb5ID3f3G6ZcrSVoyyh74I8A7q+p5wNnAW5OcAewErq+q04Hru8eSpBkZGuBVdaiqvtpNfw/YD5wCXADs7rrtBi6cVpGSpCOt6hh4ki3AmcDNwMlVdQh6IQ88fYVldiRZSLKwuLg4XrWSpB8aOcCTPAX4JPCOqvruqMtV1a6qmq+q+bm5ubXUKEkaYKQAT/JEeuH9kar6VNd8f5JN3fxNwOHplChJGmSUs1ACXAHsr6r3983aA2zvprcD102+PEnSSjaM0Occ4A3A7Ulu7dreDVwKfDzJJcC/Aq+bTomSpEGGBnhVfRHICrNfOtlyJEmj8kpMSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMM
cElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSo4YGeJIrkxxOsq+v7b1JvpXk1u52/nTLlCQtN8oe+FXAeQPaL6+qrd3tM5MtS5I0zNAAr6ovAA/OoBZJ0iqMcwz8bUlu6w6xbJxYRZKkkaw1wD8IPAfYChwC3rdSxyQ7kiwkWVhcXFzj5iRJy60pwKvq/qp6tKoeAz4MnHWUvruqar6q5ufm5tZapyRpmTUFeJJNfQ9fC+xbqa8kaTo2DOuQ5GrgxcBJSe4Ffhd4cZKtQAH3AG+aYo2SpAGGBnhVXTyg+Yop1CJJWgWvxJSkRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRg0N8CRXJjmcZF9f24lJ9iY50N1vnG6ZkqTlRtkDvwo4b1nbTuD6qjoduL57LEmaoaEBXlVfAB5c1nwBsLub3g1cOOG6JElDrPUY+MlVdQigu3/6Sh2T7EiykGRhcXFxjZuTJC039Q8xq2pXVc1X1fzc3Ny0NydJjxtrDfD7k2wC6O4PT64kSdIo1hrge4Dt3fR24LrJlCNJGtUopxFeDXwJeG6Se5NcAlwKvDzJAeDl3WNJ0gxtGNahqi5eYdZLJ1yLJGkVvBJTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGDT0PXLBl56fXZbv3XPrqddnuev29sH5/s9Qi98AlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0a6wcdktwDfA94FHikquYnUZQkabhJ/CLPS6rqgQmsR5K0Ch5CkaRGjbsHXsA/Jingz6tq1/IOSXYAOwA2b9485uYeX9bztyk1O4+331wFX9uTMu4e+DlV9QLgVcBbk7xweYeq2lVV81U1Pzc3N+bmJElLxgrwqrqvuz8MXAucNYmiJEnDrTnAk/xYkqcuTQOvAPZNqjBJ0tGNcwz8ZODaJEvr+duq+oeJVCVJGmrNAV5VdwPPn2AtkqRV8DRCSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMm8X3g0sQ8Hr+Zb734jYDtcw9ckhplgEtSowxwSWqUAS5JjTLAJalRnoUi4RkZapN74JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaNVaAJzkvyZ1JDibZOamiJEnDrTnAkxwHfAB4FXAGcHGSMyZVmCTp6MbZAz8LOFhVd1fVfwMfBS6YTFmSpGHG+TbCU4B/63t8L/Bzyzsl2QHs6B5+P8mda9zeScADa1x2mqxrdaxrdaxrdY7VushlY9X27EGN4wR4BrTVEQ1Vu4BdY2ynt7Fkoarmx13PpFnX6ljX6ljX6hyrdcF0ahvnEMq9wLP6Hj8TuG+8ciRJoxonwP8JOD3JqUmeBGwD9kymLEnSMGs+hFJVjyR5G/A54Djgyqq6Y2KVHWnswzBTYl2rY12rY12rc6zWBVOoLVVHHLaWJDXAKzElqVEGuCQ16pgK8CSvS3JHkseSrHi6zUqX8HcfqN6c5ECSj3Ufrk6irhOT7O3WuzfJxgF9XpLk1r7bfyW5sJt3VZJv9s3bOqu6un6P9m17T1/7eo7X1iRf6p7v25L8Ut+8iY7XsK98SHJ89/cf7MZjS9+8d3XtdyZ55Th1rKGu30zy9W58rk/y7L55A5/TGdX1xiSLfdv/tb5527vn/UCS7TOu6/K+mr6R5KG+edMcryuTHE6yb4X5SfKnXd23JXlB37zxxquqjpkb8DzgucBNwPwK
fY4D7gJOA54EfA04o5v3cWBbN/0h4C0TquuPgZ3d9E7gsiH9TwQeBH60e3wVcNEUxmukuoDvr9C+buMF/DRwejf9DOAQcMKkx+tor5e+Pr8OfKib3gZ8rJs+o+t/PHBqt57jZljXS/peQ29Zqutoz+mM6noj8GcDlj0RuLu739hNb5xVXcv6/wa9EyumOl7dul8IvADYt8L884HP0rt25mzg5kmN1zG1B15V+6tq2JWaAy/hTxLgXOCart9u4MIJlXZBt75R13sR8Nmq+s8JbX8lq63rh9Z7vKrqG1V1oJu+DzgMzE1o+/1G+cqH/nqvAV7ajc8FwEer6uGq+iZwsFvfTOqqqhv7XkNfpnetxbSN8xUZrwT2VtWDVfVtYC9w3jrVdTFw9YS2fVRV9QV6O2wruQD4q+r5MnBCkk1MYLyOqQAf0aBL+E8BngY8VFWPLGufhJOr6hBAd//0If23ceSL5w+7t0+XJzl+xnU9OclCki8vHdbhGBqvJGfR26u6q695UuO10utlYJ9uPL5Db3xGWXaadfW7hN5e3JJBz+ks6/qF7vm5JsnSBX3HxHh1h5pOBW7oa57WeI1ipdrHHq9xLqVfkySfB35ywKz3VNV1o6xiQFsdpX3sukZdR7eeTcDP0js/fsm7gH+nF1K7gN8Gfn+GdW2uqvuSnAbckOR24LsD+q3XeP01sL2qHuua1zxegzYxoG353zmV19QQI687yeuBeeBFfc1HPKdVddeg5adQ198BV1fVw0neTO/dy7kjLjvNupZsA66pqkf72qY1XqOY2utr5gFeVS8bcxUrXcL/AL23Jhu6vahVXdp/tLqS3J9kU1Ud6gLn8FFW9YvAtVX1g751H+omH07yl8BvzbKu7hAFVXV3kpuAM4FPss7jleTHgU8Dv9O9tVxa95rHa4BRvvJhqc+9STYAP0HvLfE0vy5ipHUneRm9/xRfVFUPL7Wv8JxOIpCG1lVV/9H38MPAZX3LvnjZsjdNoKaR6uqzDXhrf8MUx2sUK9U+9ni1eAhl4CX81ftU4EZ6x58BtgOj7NGPYk+3vlHWe8Sxty7Elo47XwgM/LR6GnUl2bh0CCLJScA5wNfXe7y65+5aescGP7Fs3iTHa5SvfOiv9yLghm589gDb0jtL5VTgdOArY9SyqrqSnAn8OfCaqjrc1z7wOZ1hXZv6Hr4G2N9Nfw54RVffRuAV/N93olOtq6vtufQ+EPxSX9s0x2sUe4Bf7s5GORv4TreTMv54TeuT2bXcgNfS+1/pYeB+4HNd+zOAz/T1Ox/4Br3/Qd/T134avX9gB4FPAMdPqK6nAdcDB7r7E7v2eeAv+vptAb4FPGHZ8jcAt9MLor8BnjKruoCf77b9te7+kmNhvIDXAz8Abu27bZ3GeA16vdA7JPOabvrJ3d9/sBuP0/qWfU+33J3Aqyb8eh9W1+e7fwdL47Nn2HM6o7r+CLij2/6NwM/0Lfur3TgeBH5llnV1j98LXLpsuWmP19X0zqL6Ab38ugR4M/Dmbn7o/fjNXd325/uWHWu8vJRekhrV4iEUSRIGuCQ1ywCXpEYZ4JLUKANckhplgEtSowxwSWrU/wB/9oZOyVWQ9QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Histogram of tanh(2 * N(0, 1)) * 0.99999, the squashing used for qt0 in gen_data_batch.\n",
    "plt.hist(jnp.tanh(jax.random.normal(jax.random.PRNGKey(1), (100,))*2)*0.99999)\n",
    "plt.title('tanh(2 * N(0, 1)) * 0.99999')\n",
    "plt.xlabel('value');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.jit, static_argnums=(1, 2), backend='cpu')\n",
    "def gen_data(seed, batch, num):\n",
    "    \"\"\"Simulate `batch` trajectories over t in [0, 1] and flatten them into samples.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed for the random initial conditions.\n",
    "        batch: number of trajectories (static under jit).\n",
    "        num: number of time points per trajectory (static under jit).\n",
    "\n",
    "    Returns:\n",
    "        (state, conditionals, q_tt) with shapes (num*batch, 2), (num*batch, 1)\n",
    "        and (num*batch, 1): rows [q, q_t], g, and the true acceleration.\n",
    "    \"\"\"\n",
    "    # Independent keys via split; adding integers to a key is not a supported\n",
    "    # way to derive independent random streams.\n",
    "    q_key, qt_key, g_key = jax.random.split(jax.random.PRNGKey(seed), 3)\n",
    "    q0 = jax.random.uniform(q_key, (batch,), minval=-10, maxval=10)\n",
    "    qt0 = jax.random.uniform(qt_key, (batch,), minval=-0.99, maxval=0.99)\n",
    "    g = jax.random.normal(g_key, (batch,))*10\n",
    "\n",
    "    # Interleave as (q, q_t, g) triples, the layout ofunc expects.\n",
    "    y0 = jnp.stack([q0, qt0, g]).T.ravel()\n",
    "\n",
    "    yt = odeint(ofunc, y0, jnp.linspace(0, 1, num=num), mxsteps=300)\n",
    "\n",
    "    qall = yt[:, ::3]\n",
    "    qtall = yt[:, 1::3]\n",
    "    gall = yt[:, 2::3]\n",
    "\n",
    "    state = jnp.stack([qall, qtall]).reshape(2, -1).T\n",
    "    conditionals = gall.reshape(1, -1).T\n",
    "    q_tt = qdotdot(qall, qtall, gall)[1].reshape(1, -1).T\n",
    "    return state, conditionals, q_tt\n",
    "\n",
    "@partial(jax.jit, static_argnums=(1,))\n",
    "def gen_data_batch(seed, batch):\n",
    "    \"\"\"Sample `batch` random states together with their true derivatives.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed; every returned quantity depends on it.\n",
    "        batch: number of samples (static under jit).\n",
    "\n",
    "    Returns:\n",
    "        (state, conditionals, target) with shapes (batch, 2), (batch, 1) and\n",
    "        (batch, 2): rows [q, q_t], g, and [q_t, q_tt] from qdotdot.\n",
    "    \"\"\"\n",
    "    q_key, qt_key, g_key = jax.random.split(jax.random.PRNGKey(seed), 3)\n",
    "    q0 = jax.random.uniform(q_key, (batch,), minval=-10, maxval=10)\n",
    "    # Squash a normal sample so |q_t| <= 0.99999 < 1. The key must come from\n",
    "    # `seed`: a fixed key would give identical q_t values for every batch.\n",
    "    qt0 = jnp.tanh(jax.random.normal(qt_key, (batch,))*2)*0.99999\n",
    "    g = jax.random.normal(g_key, (batch,))*10\n",
    "\n",
    "    state = jnp.stack([q0, qt0], axis=1)\n",
    "    conditionals = g[:, None]\n",
    "    target = jnp.stack(qdotdot(q0, qt0, g)).T\n",
    "    return state, conditionals, target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[ 1.4900093 ,  0.9574616 ],\n",
       "              [-8.006279  , -0.95951325],\n",
       "              [-2.1367955 , -0.21696056],\n",
       "              [ 7.883566  ,  0.6245134 ],\n",
       "              [ 1.9313316 ,  0.33272812]], dtype=float32),\n",
       " DeviceArray([[ 5.672981  ],\n",
       "              [13.351437  ],\n",
       "              [ 3.9785006 ],\n",
       "              [17.436966  ],\n",
       "              [-0.24577108]], dtype=float32),\n",
       " DeviceArray([[ 0.9574616 ,  0.00400571],\n",
       "              [-0.95951325,  0.00833027],\n",
       "              [-0.21696056,  3.2232604 ],\n",
       "              [ 0.6245134 ,  2.8466694 ],\n",
       "              [ 0.33272812, -0.15006456]], dtype=float32))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Small batch (seed 0, 5 samples) to inspect the data layout.\n",
    "cstate, cconditionals, ctarget = gen_data_batch(0, 5)\n",
    "cstate, cconditionals, ctarget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# qdotdot(jnp.array([0]), jnp.array([0.9]), jnp.array([10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0.29830917716026306 {'act': [4],\n",
    "# 'batch_size': [27.0], 'dt': [0.09609870774790222],\n",
    "# 'hidden_dim': [596.0], 'l2reg': [0.24927677946969878],\n",
    "# 'layers': [4.0], 'lr': [0.005516656601005163],\n",
    "# 'lr2': [1.897157209816416e-05], 'n_updates': [4.0]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loaded = pkl.load(open('./params_for_loss_0.29429444670677185_nupdates=1.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration; the whole object is passed to extended_mlp when building the model.\n",
    "# NOTE(review): not every key is read in this notebook (e.g. the training cell defines its own lr) - confirm which ones extended_mlp uses.\n",
    "args = ObjectView({'dataset_size': 200,\n",
    " 'fps': 10,\n",
    " 'samples': 100,\n",
    " 'num_epochs': 80000,\n",
    " 'seed': 0,\n",
    " 'loss': 'l1',\n",
    " 'act': 'softplus',\n",
    " 'hidden_dim': 500,\n",
    " 'output_dim': 1,\n",
    " 'layers': 4,\n",
    " 'n_updates': 1,\n",
    " 'lr': 0.001,\n",
    " 'lr2': 2e-05,\n",
    " 'dt': 0.1,\n",
    " 'model': 'gln',\n",
    " 'batch_size': 68,\n",
    " 'l2reg': 5.7e-07,\n",
    "})\n",
    "# args = loaded['args']\n",
    "# Base PRNG key from the configured seed.\n",
    "rng = jax.random.PRNGKey(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best-so-far trackers: no parameters yet, and inf so that any finite loss compares lower.\n",
    "best_params = None\n",
    "best_loss = np.inf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_random_params, nn_forward_fn = extended_mlp(args)\n",
    "# Seed from the run configuration and split off a dedicated initialization key;\n",
    "# jax.random.split (not integer arithmetic on the key) advances the PRNG state.\n",
    "rng = jax.random.PRNGKey(args.seed)\n",
    "rng, init_rng = jax.random.split(rng)\n",
    "# 3 inputs = the concatenated (q, p, conditionals) vector built by learned_dynamics.\n",
    "_, init_params = init_random_params(init_rng, (-1, 3))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model and its initial parameters are now set up. Next, let's train it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Idea: add identity before inverse:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bookkeeping for the training loop (assumed to be defined in a later cell).\n",
    "best_small_loss = np.inf\n",
    "iteration = 0\n",
    "total_epochs = 100\n",
    "minibatch_per = 3000\n",
    "train_losses, test_losses = [], []\n",
    "\n",
    "# Peak learning rate for OneCycleLR below; the schedule's floor is lr / final_div_factor.\n",
    "lr = 1e-3 #1e-3\n",
    "\n",
    "final_div_factor=1e4\n",
    "\n",
    "# One-cycle cosine learning-rate schedule.\n",
    "@jax.jit\n",
    "def OneCycleLR(pct):\n",
    "    \"\"\"Learning rate for training progress `pct` in [0, 1].\n",
    "\n",
    "    Progress is remapped onto phase [0.3, 1] of one cosine period, which skips\n",
    "    most of the warm-up: the rate starts part-way up, peaks at `lr` (phase 0.5)\n",
    "    and anneals to lr / final_div_factor at pct = 1.\n",
    "\n",
    "    NOTE: `lr` and `final_div_factor` are globals read when jit traces this\n",
    "    function, so their values at that time are baked in; re-run this cell\n",
    "    after changing them.\n",
    "    \"\"\"\n",
    "    phase_start = 0.3\n",
    "    phase = pct * (1-phase_start) + phase_start\n",
    "    peak_lr = lr\n",
    "    floor_lr = lr/final_div_factor\n",
    "    scale = 1.0 - (jnp.cos(2 * jnp.pi * phase) + 1)/2\n",
    "    return floor_lr + (peak_lr - floor_lr)*scale\n",
    "    \n",
    "\n",
    "# Adam with OneCycleLR as its step-size schedule. JAX optimizers evaluate\n",
    "# step_size(i) with the `i` given to opt_update, so callers should pass the\n",
    "# training-progress fraction in [0, 1] (what OneCycleLR expects), not a step count.\n",
    "opt_init, opt_update, get_params = optimizers.adam(\n",
    "    OneCycleLR\n",
    ")\n",
    "opt_state = opt_init(init_params)\n",
    "# opt_state = opt_init(best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'lr schedule')"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEICAYAAABcVE8dAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxc5X3v8c9vZrSv1mpLtiwvso0BY0BgdkiAxlAcJ21DgWwFGkLuJYWmacO9aZP0Nnlx29w2996GhppA2BIIhDQxhCahScEEMNjG2HgBb9iWLNtarMXat6d/zJGRhWRkazTnaOb7fr300syjmXN+Ohp955nnLI855xARkcQX8rsAERGJDwW+iEiSUOCLiCQJBb6ISJJQ4IuIJAkFvohIklDgy5RkZnvN7KpJXscLZvanMVrWN8zssVg/VuRkKPBFRJKEAl8SjplF/K5BJIgU+DLleUMgPzGzx8ysDfiTUR5zrZltM7OjZnbAzL487GcrzexNM2szs91mtnzYU2eb2cve835tZkXDnneBmb1iZi1mtsnMrhj2szlm9qL3vOeB4c+7wsxqR9Q35hDVidYjcjIU+JIoVgI/AfKBH47y8weAzzvncoAzgN8CmNn5wCPAX3rPvQzYO+x5NwE3AyVAKvBl73nlwC+AbwIFXvvTZlbsPe9HwAaiQf93wGdP5Zcax3pExk2BL4niVefcz5xzg865rlF+3gcsNrNc51yzc+4Nr/1W4EHn3PPecw84594e9rwfOOd2eMt8EljqtX8KeM4595z3vOeB9cC1ZlYBnAf8jXOuxzm3BnjmFH+vMddzisuTJKbAl0RR8wE//0OiIbnPG2q50GufBew+wfMODbvdCWR7t2cDn/CGWVrMrAW4BJgBlAHNzrmOYc/dN87fY6QTrUfkpGjnliSKE1721Tm3DlhpZinAHUR767OIvlHMO4X11QCPOuc+N/IHZjYbmGZmWcNCv2JYjR1A5rDHh4GxhmjGXI/IyVIPXxKemaWa2SfNLM851we0AQPejx8AbjazK80sZGblZrZoHIt9DFhhZh8xs7CZpXs7Y2c65/YRHXb5W2/dlwArhj13B5BuZr/vvQH9NZB2sus5hU0hSU6BL8ni08Be7yie24mOjeOce53oTtnvAK3Ai0SHUU7IOVdDdEfx/wQaiPbE/5L3/qduApYBR4CvE90xPPTcVuC/Ad8HDhDt8R931M5JrEdk3EwToIiIJAf1EkREkoQCX0QkSSjwRUSShAJfRCRJBPo4/KKiIldZWel3GSIiU8qGDRsanXPvO7cj0IFfWVnJ+vXr/S5DRGRKMbNRz+zWkI6ISJJQ4IuIJAkFvohIklDgi4gkCQW+iEiSiFvgm9lpZnafNxXdF+K1XhERiRpX4JvZg2ZWb2ZbRrQvN7N3zGyXmd19omU457Y7524HrgeqT71kERE5FeM9Dv8h4LsMu8SrN2nDvcDVRC/tus7MVgNh4J4Rz7/FOVdvZh8F7vaWJXEwMOjo7O2nq3eAjt6BY7e7+wbpGxykr3+Q/kFH38AgvcNu9w1Ev/cPDDJ0QVWzoe92bPnH2jDMICUcIjUSIi0cIiVipIbDpEaibSlhIy0SIjUcJj0lRGZahOzUCJlpYVLCGl0UmWzjCnzn3BozqxzRfD6wyzm3B8DMngBWOufuAa4bYzmrgdVm9guikzy/j5ndBtwGUFFRMZ7yEppzjqM9/TS199LU3kNLZx+tXe99tXV734e1dfREg72zd4Ce/kG/f4VxSY2EyEoNk5UWISs1QlZa9HZ+Zir5GSnkZ6aQl5Fy3P1oWyoFWamEQ/bBKxFJchM507ac4+cRrSU64cOozOwK4A+Izuzz3FiPc86tAlYBVFdXJ+zF+vsGBjnc1s2h1m4OtnZzuK2bhvYemtp7afS+N7X30NjRS+8JQjsnLUJuRgq5GSnkZUSoLMwiOz0ampmpYTJS
w2SlRshIDZOZGibTa0871uuOfkXCRuqw29F2IxIKEbL35uYb6u07HCOnUhh0jr5+R8/AAL390U8Jvf3RTw69AwP09jt6vU8S3X3RN6X2ngE6e/rp6B2go6efjt5+Onqib1Zt3f3UNnfR0tlLa1cfg2O8GkIGBVmpFGWnUZyTRrH3/dj9nDSm56VTlpdBRmp44n88kSlqIoE/WpdqzIB2zr0AvDCB9U0prZ197D/Syf4jndQ0d3rB3nUs4Bvae94XmClhoyg7jcLsaHgtKM2hKCeVoqw0inJSKchKY5rX083LSCE7LUIkaEMhqQApMV/s4GD0k05rZx8tXb20dPbR0tVHc0f0jbGhvYeGo700tPewp6GDhvaeUd8op2WmMCMvg7L8DMry0ynLz2BGXvR7RUEmJTlpxw1ZiSSSiQR+LdFJoIfMBOomVk6Uma0AVsyfPz8Wi5sUzjnqWrvZ29jB/iOd7GvqpMYL+P1HOmnt6jvu8dlpEWbkpTM9L52F03OYnpdx7H5ZXgbTc9PJzYgobMYQCtmxN7qK9+b/HpNzjrbufhrbe6hv6+FQWxd1Ld3UtXRxsLWb2uZOXn+3ibbu/uOel54SYnZBFrMLM72vLCoLo/fL8jM0dCRT2rinOPTG8J91zp3h3Y8QnYz5SqLzcq4DbnLObY1VcdXV1c7vi6f1DQyyr6mDXfXt7KpvZ3dDh/e9nc7egWOPSwkbM6dlMqsgk4qCaG+xoiCLioJMZhZkkJse+16vTFx7Tz8HW7o40NJFjffGvbepk/1HOtjX1HncPpCUsFFRkElVSQ5VpdnML8mmqiSHucVZpKdoqEiCw8w2OOfedzTkuHr4ZvY4cAVQZGa1wNedcw+Y2R3Ar4gemfNgrMLejx7+UI99e10b2w62sf1gGzsOH2VfUyf9wwaPy/LSmVeSzfXVs5hfks3c4ixmF2YxPTddvb8pKDstQlVpDlWlOe/72eCg4/DRbvY1dbKvqYO9TZ3saWhnR/1Rnt9+mAHvdREymF2Y5b0BZLOgNIfTy3KZW5yt14QESqAnMZ+sHn5v/yC76tvZdrCNbXXRcN92sO24YZg5RVks8Hpx80uymVcc/cpKC/QVpSVOevoH2NvYyc76o+w83H7s+7uNHcc6COkpIRZNz+X0slzOKM/j9LJcFpTm6NOATLqxevgJH/jOOfY2dfJmTTObalp5s6aFbXVt9A5EP6oP/VOeNiOXxWW5LJ6Rw8LpuWQr2OUU9A0Msruhna0H2tha18bWula2HWzjqLevIBIy5pdkc3pZHksr8jmnIp+FpTnB2/kuU9qUCvxhQzqf27lz50k//7U9Tby8q5GNNS1srm091nPPTA1zZnkeS2flc0Z5HqfNyGVOUZY+dsukcs5Rc6SLrXWtbKlrZWtdG1sOtNLY3gtARkqYJTPzOLtiGmdX5HN2RT4lOek+Vy1T2ZQK/CGn2sP/yk8289SGGhZOz2XprDzOmpnP0op8qkpyFO4SCM45apu7eGN/Mxv3t7CxpoVtda30DUT/H2cVZLBsTiHL5hRwwdxCZk7L0BFcMm5JFfgNR3vISoueZCQyVXT3DbC1ro2N+5tZt/cIr797hObO6KfTsrx0ls2NvgEsm1tIZWGm3gBkTFMq8Cc6pCOSCAYHHTvr23nt3SZe23OE195tOjYMVJ6fwaVVRVy2oJiL5xWRl6nDfuU9UyrwhwThOHyRoHDOsbuhg7V7mvjdzkZe3t3I0e5+QgZLZuZz2YJiLqsqYumsfO0ETnIKfJEE0z8wyKbaFtbsaGTNzgY21bQw6CAnPcKHFpZw9eJSLl9YrJP+kpACXyTBtXb28cruRn77dj2/fbuepo5eUsLGBXMLuXpxKVeeVkp5fobfZUocTKnA1xi+yMQMDDo27m/m+W2HeX77YfY0dABwelkuy0+fznVnlTGnKMvnKmWyTKnAH6Ievkhs7G5o5z+2HebX2w6zYV8zAGeU57JiSRnXnVWmnn+CUeCLCAAH
W7v4xeaDPLP5IJtqWgA4d/Y0ViyZwbVLZuikrwSgwBeR99nf1Mkzm+t4ZlMdbx86Ssjgkqpirq+eyVWnleq6P1OUAl9ETmjn4aOs3lTH0xtqqWvtJi8jhZVLy/jEubM4ozxXJ3pNIVMq8LXTVsQ/A4OOV3Y38tT6Wn659RC9/YMsmp7DTcsq+PjZ5eToMM/Am1KBP0Q9fBF/tXb18cymOp5cX8Pm2lYyU8N8/OxyPn3hbBZNz/W7PBmDAl9EJmRTTQuPrd3H6k119PQPcn5lAZ+6cDbLT59OakRn9gaJAl9EYqK5o5enNtTw2Nr97D/SyfTcdG6+uJIbl1XorN6AUOCLSEwNDjpe3NHA/S/t4ZXdTWSnRbhpWQU3X1zJjDwd1+8nBb6ITJotB1pZtWYPv3jrIAZ89Kwybr9iHgtGmStYJt+UCnwdpSMyNdUc6eTBl9/lx+tq6Oob4NozZ3DXlVWjThIvk2dKBf4Q9fBFpqaWzl6+/9K7/ODld+nsG+C6JWXceeV85pco+ONBgS8icdfc0cv9L+3hoVf20tU3wIolZdx5VRXzirP9Li2hKfBFxDdHOnpZtWYPj7y6l57+QW44bxZ3XlWl6/ZMEgW+iPiuqb2Hf/7tLh5bu4/USIjPXzaPz102R/NPx9hYga+zJUQkbgqz0/jGR0/n+S9dzuULivnOf+zg8m+/wE821DI4GNzOZ6JQ4ItI3M0pyuJ7nzqXp79wETOnZfDlpzbxR/e9wpYDrX6XltAU+CLim3NnT+Pp2y/i23+0hP1HOlnx3d/x1z97i5bOXr9LS0gKfBHxVShkfKJ6Fr/5iyv4k4sqefz1Gj78jy/ys40HCPI+xqkokIFvZivMbFVrqz7eiSSLvIwUvr7idH7xZ5cwuzCTu378Jrc8tI66li6/S0sYOkpHRAJnYNDxyKt7+YdfvkPI4CvXLOJTy2YTCmkSlvHQUToiMmWEQ8bNF8/h139+GefMnsbXfr6Vm76/lgPq7U+IAl9EAmtWQSaP3HI+//CHS3irtpXl/3cNqzfV+V3WlKXAF5FAMzOuP28Wz915KVUl2fzZ4xu564mNtHb1+V3alKPAF5EpYXZhFk9+/kL+/KoFPLP5INf+v5fYuL/Z77KmFAW+iEwZkXCIO6+q4qnbL8QMrv/XV3n01b06fHOcFPgiMuWcUzGNZ794CZfML+Jvfr6VP//xm3T29vtdVuAp8EVkSsrPTOWBz57HX1y9gJ9vquNj977MnoZ2v8sKNAW+iExZoZDxxSureOSW82k42sPKe1/mlV2NfpcVWAp8EZnyLq0q5pkvXsKMvHQ+8+DrPLmuxu+SAimugW9mWWa2wcyui+d6RSTxzZyWyU++cBEXzivkr57ezP/+97d1yeURxhX4ZvagmdWb2ZYR7cvN7B0z22Vmd49jUV8BnjyVQkVEPkhuego/+JPzuGlZBfe9uJs7Hn+D7r4Bv8sKjPFOM/MQ8F3gkaEGMwsD9wJXA7XAOjNbDYSBe0Y8/xZgCbAN0JxmIjJpIuEQ3/rYGcwtyuJbz22nsf11vv/ZanLTU/wuzXfjCnzn3BozqxzRfD6wyzm3B8DMngBWOufuAd43ZGNmHwKygMVAl5k955wbnEDtIiKjMjP+9NK5lOSm86Ufv8lN96/l4ZvPpzA7ze/SfDWRMfxyYPiekVqvbVTOua865+4CfgTcP1bYm9ltZrbezNY3NDRMoDwRSXYfPauM+z9Tzc7D7XziX1/lYGtyX3xtIoE/2nVKP3APiXPuIefcsyf4+SrnXLVzrrq4uHgC5YmIwIcWlfDorcuob+vhhlVrOdTa7XdJvplI4NcCs4bdnwnE5DJ2mgBFRGLp/DkFPHzL+TS193Lj/Ws53JacoT+RwF8HVJnZHDNLBW4AVseiKOfcM8652/Ly8mKxOBERzp09jYdvOY/6tm5uXLWW+iQM/fEelvk48Cqw0MxqzexW51w/cAfwK2A78KRzbmssilIPX0Qmw7mzoz39
w23dfPqB12ntTK5LLGuKQxFJOi/vauTmH6xjycw8Hr11GRmpYb9LiilNcSgi4rl4fhHf+eOlbNjfzB0/eoO+geQ4QjyQga8hHRGZbL+/ZAb/a+UZ/Obtev7mZ1uS4pr6gQx87bQVkXj49AWzueND83liXQ0P/O5dv8uZdOO9tIKISEL60tUL2N3Qzree286coiyuPK3U75ImTSB7+CIi8RIKGf94/VmcXpbLnz2+kbcPtfld0qQJZOBrDF9E4ikzNcL3P3MeWWkRbn90A23diXm4ZiADX2P4IhJv0/PSufeT51DT3MVfPbU5IXfiBjLwRUT8cF5lAXcvX8Qvtx7iwZf3+l1OzCnwRUSG+dNL5/B7i0u557ntbNjX7Hc5MRXIwNcYvoj4xcz49ifOoiw/gzuf2Eh7T7/fJcVMIANfY/gi4qe8jBT+6fqzqGvp4pvPbvO7nJgJZOCLiPiturKAz18+jyfW1fCb7Yf9LicmFPgiImO466oqFk3P4StPv8WRjl6/y5mwQAa+xvBFJAjSImG+88dLae3q5Ws/3+J3ORMWyMDXGL6IBMVpM3L54oereHbzQV7cMbXn2Q5k4IuIBMnnL5/L3KIsvvbzLXT3DfhdzilT4IuIfIC0SJhvfuwM9jV18i8v7Pa7nFOmwBcRGYeL5hexcmkZ972wmz0N7X6Xc0oU+CIi4/TV3z+NtJQQX18dk+m7406BLyIyTiU56dx5ZRUv7WzkpZ1TbwduIANfh2WKSFB9+sLZlOdn8Pe/fJvBwal1Rc1ABr4OyxSRoEqLhPnyRxaw5UAbz2yu87uckxLIwBcRCbKVZ5Vz2oxc/s+v36G3f9DvcsZNgS8icpJCIePuaxZRc6SLH762z+9yxk2BLyJyCi6rKuKieYXc+5+7pszJWAp8EZFTYGZ88cNVNLb38pMNtX6XMy4KfBGRU3TB3AKWzspn1Zo99A8EfyxfgS8icorMjNsvn8f+I538+5ZDfpfzgQIZ+DoOX0Smit9bXMrc4iy+98JunAv2cfmBDHwdhy8iU0UoFO3lbzvYxpqdjX6Xc0KBDHwRkankY0vLmZ6bzvde2OV3KSekwBcRmaDUSIibL65k7Z4jvHPoqN/ljEmBLyISA9dXzyI1EuKxtcE9EUuBLyISA9OyUlmxpIyfvlFLe0+/3+WMSoEvIhIjn75wNh29A/xs4wG/SxmVAl9EJEbOmpnHouk5PLW+xu9SRqXAFxGJETPjj86dyabaVnYcDt7OWwW+iEgMffzsciIhC+T1dRT4IiIxVJidxocXlfDTNw4E7vo6CnwRkRj7+NnlNLb38Nq7R/wu5ThxC3wzu8LMXjKz+8zsinitV0Qk3j60qISs1DDPbArWFIjjCnwze9DM6s1sy4j25Wb2jpntMrO7P2AxDmgH0oHgDW6JiMRIekqYqxeX8suthwI1BeJ4e/gPAcuHN5hZGLgXuAZYDNxoZovN7Ewze3bEVwnwknPuGuArwN/G7lcQEQme65aU0dLZx+92NfhdyjGR8TzIObfGzCpHNJ8P7HLO7QEwsyeAlc65e4DrTrC4ZiBtrB+a2W3AbQAVFRXjKU9EJHAuXVBETlqEX245xIcXlfpdDjCxMfxyYPjZBbVe26jM7A/M7F+BR4HvjvU459wq51y1c666uLh4AuWJiPgnLRLm8oXF/GZ7PQODwbhO/kQC30ZpG/O3cs791Dn3eefcHzvnXjjhgjUBiogkgKsXl9LU0cubNc1+lwJMLPBrgVnD7s8EYrJLWhOgiEgiuGJhCZGQ8fy2er9LASYW+OuAKjObY2apwA3A6tiUJSIy9eVlpLBsbgHPbwvGfLfjPSzzceBVYKGZ1ZrZrc65fuAO4FfAduBJ59zWWBSlIR0RSRRXLipld0MHNUc6/S5lfIHvnLvROTfDOZfinJvpnHvAa3/OObfAOTfPOfetWBWlIR0RSRSXLSgC4KUAzHerSyuIiEyiecXZzMhLD8Tx+IEMfA3piEiiMDMu
mV/Ey7uafD88M5CBryEdEUkkly4oprWrj7cO+NuJDWTgi4gkkovnFWIGv9vp77BOIANfQzoikkgKs9NYWJrj++WSAxn4GtIRkURzXmUBb+xr9nVSlEAGvohIoqmunEZH7wBvH/JvrlsFvohIHJxXWQDAur3+DesEMvA1hi8iiaYsP4Py/AzW7/XvQmqBDHyN4YtIIjqvchqv7z2Cc/4cjx/IwBcRSUTnVhbQcLSH2uYuX9avwBcRiZMl5dFRiy0+nYClwBcRiZNFM3JICRubFfjv0U5bEUlEaZEwC0pz1MMfTjttRSRRLZmZx+baVl923AYy8EVEEtUZ5Xm0dvX5suNWgS8iEkdnejtuN9fGf1hHgS8iEkcLpw/tuG2J+7oV+CIicZQWCTOvOJsdPlxTJ5CBr6N0RCSRLSjNYcfh9rivN5CBr6N0RCSRVZVkc6Cli46e/riuN5CBLyKSyKpKcwDYVR/fXr4CX0QkzhaUZgOw43B8x/EV+CIicVZRkElqOKQevohIoouEQ8wtzlIPX0QkGfhxpI4CX0TEB34cqaPAFxHxwbyS6I7bvU0dcVtnIANfJ16JSKKrKMgEoOZIZ9zWGcjA14lXIpLoZnmBvz/ZA19EJNHlZaSQl5GiwBcRSQYVBZnsPxK/6+Ir8EVEfFJRkEmtevgiIolvVkEmtc1dDAzGZ7pDBb6IiE8qCjLpHRjkcFt3XNanwBcR8UlFnI/UUeCLiPhEgS8ikiRm5KcTDlncTr5S4IuI+CQlHGJGXnrceviRuKwFMLMQ8HdALrDeOfdwvNYtIhJUM/LSg7XT1sweNLN6M9syon25mb1jZrvM7O4PWMxKoBzoA2pPrVwRkcRSkptO/dGeuKxrvEM6DwHLhzeYWRi4F7gGWAzcaGaLzexMM3t2xFcJsBB41Tn3JeALsfsVRESmrpKcNOrb4hP44xrScc6tMbPKEc3nA7ucc3sAzOwJYKVz7h7gupHLMLNaoNe7OzDWuszsNuA2gIqKivGUJyIyZZXmptPe009HTz9ZaZM7yj6RnbblQM2w+7Ve21h+CnzEzP4ZWDPWg5xzq5xz1c656uLi4gmUJyISfKW5aQBxGdaZyNuJjdI25vnBzrlO4NYJrE9EJOGU5KQDcLitmzlFWZO6ron08GuBWcPuzwTqJlZOlCZAEZFkMdTDj8eROhMJ/HVAlZnNMbNU4AZgdSyK0gQoIpIsir0efkMchnTGe1jm48CrwEIzqzWzW51z/cAdwK+A7cCTzrmtsShKPXwRSRa56RHSU0Jx6eGP9yidG8dofw54LqYVRZf7DPBMdXX152K9bBGRIDEzSuN0LL4urSAi4rOSnLTAj+FPGg3piEgyKclNj8vJV4EMfO20FZFkUpqjIR0RkaRQkptGe08/7T39k7qeQAa+hnREJJkcO9t2ksfxAxn4GtIRkWRSmBUN/CMdvR/wyIkJZOCLiCST/MwUAFo6+yZ1PQp8ERGfTctMBaC5Mwl7+BrDF5FkktQ9fI3hi0gyyU6LEAlZcvbwRUSSiZmRn5lCS1cS9vBFRJJNfmYqLerhi4gkvvyMFJo7krCHr522IpJs8jNTk3MMXzttRSTZTMtMoVVj+CIiiW9aVpL28EVEkk1eRgrdfYN09w1M2joU+CIiARCPs20V+CIiATAtDmfbBjLwdZSOiCSbPC/wk66Hr6N0RCTZDA3pJF0PX0Qk2WgMX0QkScTjipkKfBGRAEhPCZOeEprU6+ko8EVEAmJaZirN6uGLiCS+yb5ipgJfRCQgCrNSaWxX4IuIJLyi7FSaOnombfmBDHydeCUiyagoO43Go0nWw9eJVyKSjIpy0ujqG6Cjp39Slh/IwBcRSUaFWdGTrxrbJ2dYR4EvIhIQRTlpgAJfRCThFWdHA79hksbxFfgiIgFR5AX+ZB2po8AXEQmIwmxvDF89fBGRxJYSDpGfmaIxfBGRZFCUnabAFxFJBkXZqQp8EZFkUJidNmnX01Hgi4gESPEk
DulEJmWpozCzS4FPeutc7Jy7KF7rFhGZKoqyUzna3U933wDpKeGYLntcPXwze9DM6s1sy4j25Wb2jpntMrO7T7QM59xLzrnbgWeBh0+9ZBGRxPXesfixH9YZ75DOQ8Dy4Q1mFgbuBa4BFgM3mtliMzvTzJ4d8VUy7Kk3AY/HoHYRkYQzFPiNR2M/rDOuIR3n3BozqxzRfD6wyzm3B8DMngBWOufuAa4bbTlmVgG0OufaxlqXmd0G3AZQUVExnvJERBJGZVEm1545PebDOTCxnbblQM2w+7Ve24ncCvzgRA9wzq1yzlU756qLi4snUJ6IyNQzvySHf/nkuSycnhPzZU9kp62N0uZO9ATn3NfHtWCzFcCK+fPnn0pdIiIyion08GuBWcPuzwTqJlZOlCZAERGJvYkE/jqgyszmmFkqcAOwOjZliYhIrI33sMzHgVeBhWZWa2a3Ouf6gTuAXwHbgSedc1tjUZTmtBURiT1z7oTD7r6qrq5269ev97sMEZEpxcw2OOeqR7br0goiIkkikIGvIR0RkdgLZODrKB0RkdgL9Bi+mTUA+07x6UVAYwzLiZWg1gXBrU11nZyg1gXBrS3R6prtnHvfmauBDvyJMLP1o+208FtQ64Lg1qa6Tk5Q64Lg1pYsdQVySEdERGJPgS8ikiQSOfBX+V3AGIJaFwS3NtV1coJaFwS3tqSoK2HH8EVE5HiJ3MMXEZFhFPgiIkkiIQP/ZObaneQ6ZpnZf5rZdjPbamZ3eu3fMLMDZvam93WtD7XtNbO3vPWv99oKzOx5M9vpfZ8W55oWDtsmb5pZm5nd5df2Gm0u57G2kUX9f+81t9nMzolzXd82s7e9df+bmeV77ZVm1jVs290X57rG/NuZ2f/wttc7ZvaRONf142E17TWzN732eG6vsfJh8l5jzrmE+gLCwG5gLpAKbAIW+1TLDOAc73YOsIPo/L/fAL7s83baCxSNaPsH4G7v9t3A3/v8dzwEzPZrewGXAecAWz5oGwHXAv9OdGKgC4DX4lzX7wER7/bfD6urcvjjfNheo/7tvP+DTUAaMMf7nw3Hq64RP/9H4Gs+bK+x8mHSXmOJ2MM/Nteuc64XeAJY6UchzrmDzrk3vNtHiV5G+oOmgfTTSqokAsMAAAMDSURBVOBh7/bDwMd8rOVKYLdz7lTPtJ4w59wa4MiI5rG20UrgERe1Fsg3sxnxqss592sXvWQ5wFqiExLF1RjbaywrgSeccz3OuXeBXUT/d+Nal5kZcD3w+GSs+0ROkA+T9hpLxMA/lbl2J51FJ4E/G3jNa7rD+1j2YLyHTjwO+LWZbbDoxPEApc65gxB9MQIlPtQ15AaO/yf0e3sNGWsbBel1dwvRnuCQOWa20cxeNLNLfahntL9dULbXpcBh59zOYW1x314j8mHSXmOJGPgnPdfuZDOzbOBp4C7nXBvwPWAesBQ4SPQjZbxd7Jw7B7gG+O9mdpkPNYzKojOofRR4ymsKwvb6IIF43ZnZV4F+4Ide00Ggwjl3NvAl4EdmlhvHksb62wViewE3cnzHIu7ba5R8GPOho7Sd1DZLxMCftLl2T4WZpRD9Y/7QOfdTAOfcYefcgHNuELifSfooeyLOuTrvez3wb14Nh4c+Inrf6+Ndl+ca4A3n3GGvRt+31zBjbSPfX3dm9lngOuCTzhv09YZMmrzbG4iOlS+IV00n+NsFYXtFgD8AfjzUFu/tNVo+MImvsUQM/MDMteuNDz4AbHfO/dOw9uHjbh8Htox87iTXlWVmOUO3ie7w20J0O33We9hngZ/Hs65hjut1+b29RhhrG60GPuMdSXEB0Dr0sTwezGw58BXgo865zmHtxWYW9m7PBaqAPXGsa6y/3WrgBjNLM7M5Xl2vx6suz1XA28652qGGeG6vsfKByXyNxWNvdLy/iO7N3kH03fmrPtZxCdGPXJuBN72va4FHgbe89tXAjDjXNZfoERKbgK1D2wgoBH4D7PS+F/iwzTKBJiBv
WJsv24vom85BoI9o7+rWsbYR0Y/b93qvubeA6jjXtYvo+O7Q6+w+77F/6P2NNwFvACviXNeYfzvgq972ege4Jp51ee0PAbePeGw8t9dY+TBprzFdWkFEJEkk4pCOiIiMQoEvIpIkFPgiIklCgS8ikiQU+CIiSUKBLyKSJBT4IiJJ4r8ArUnT0Hilux0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the schedule against training progress (fraction of training completed).\n",
    "lr_progress = jnp.linspace(0, 1, num=200)\n",
    "plt.plot(lr_progress, OneCycleLR(lr_progress))\n",
    "plt.yscale('log')\n",
    "plt.xlabel('training progress (fraction)')\n",
    "plt.ylabel('learning rate')\n",
    "plt.title('lr schedule');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_flatten\n",
    "\n",
    "@jax.jit\n",
    "def loss(params, cstate, cconditionals, ctarget):\n",
    "    \"\"\"Weighted L1 error between predicted and target [q_t, q_tt].\n",
    "\n",
    "    Sample i's absolute error (both outputs) is weighted by\n",
    "    w_i = 1 + 1/sqrt(1 - q_t**2), i.e. 1 + Lorentz factor in units with c = 1,\n",
    "    so near-light-speed samples count more. The weighted sum is divided by the\n",
    "    mean weight, sum(w) / batch_size.\n",
    "    \"\"\"\n",
    "    hamiltonian = learned_dynamics(params, nn_forward_fn)\n",
    "    batched_eom = jax.vmap(partial(hamiltonian_eom, hamiltonian), (0, 0), 0)\n",
    "    predictions = batched_eom(cstate, cconditionals)[:, [0, 1]]\n",
    "\n",
    "    abs_error = jnp.abs(predictions - ctarget)\n",
    "    velocity = cstate[:, [1]]\n",
    "    weights = 1 + 1/jnp.sqrt(1.0-velocity**2)\n",
    "\n",
    "    return jnp.sum(abs_error * weights)*len(predictions)/jnp.sum(weights)\n",
    "\n",
    "@jax.jit\n",
    "def update_derivative(i, opt_state, cstate, cconditionals, ctarget):\n",
    "    \"\"\"Take one optimizer step on the loss divided by the batch size.\n",
    "\n",
    "    `i` is passed straight to opt_update (and so to its step-size schedule).\n",
    "    Returns (new_opt_state, params), where params are the values *before* the step.\n",
    "    \"\"\"\n",
    "    current_params = get_params(opt_state)\n",
    "\n",
    "    def normalized_loss(p, *batch):\n",
    "        return loss(p, *batch)/len(cstate)\n",
    "\n",
    "    grads = jax.grad(normalized_loss, 0)(current_params, cstate, cconditionals, ctarget)\n",
    "    return opt_update(i, grads, opt_state), current_params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch counter; the next cell uses it as the seed for gen_data_batch.\n",
    "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One batch of 128 samples, seeded by the current epoch.\n",
    "cstate, cconditionals, ctarget = gen_data_batch(epoch, 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(177.31618, dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss of the parameters currently held in opt_state on this batch.\n",
    "loss(get_params(opt_state), cstate, cconditionals, ctarget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warm-up call to compile update_derivative for these shapes; the result is discarded.\n",
    "update_derivative(0, opt_state, cstate, cconditionals, ctarget);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the PRNG key to seed 0.\n",
    "rng = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the epoch counter to 0 (repeats the earlier `epoch = 0` cell).\n",
    "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(128, 2)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gen_data_batch(0, 128)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[ -7.3825674 ],\n",
       "             [  0.48637673],\n",
       "             [  5.567146  ],\n",
       "             [  3.391427  ],\n",
       "             [-11.220732  ]], dtype=float32)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cconditionals[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[-8.077726  ,  0.13323684],\n",
       "             [-2.589147  , -0.12181558],\n",
       "             [-8.912823  ,  0.9635672 ],\n",
       "             [ 4.5006466 ,  0.9999864 ],\n",
       "             [ 6.4032707 ,  0.98424286]], dtype=float32)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cstate[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[ 1.3323684e-01, -6.8172374e+00],\n",
       "             [-1.2181558e-01,  4.5502922e-01],\n",
       "             [ 9.6356720e-01,  2.6673502e-03],\n",
       "             [ 9.9998641e-01,  4.3539304e-12],\n",
       "             [ 9.8424286e-01, -6.6027965e-04]], dtype=float32)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctarget[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_loss = np.inf\n",
    "best_params = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy as copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a87954721ba4f44850f9d22a1103d46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0 lr=0.0006752928602509201 loss=0.0889248475432396\n",
      "epoch=1 lr=0.0006957106525078416 loss=0.018542585894465446\n",
      "epoch=2 lr=0.0007157498621381819 loss=0.019288646057248116\n",
      "epoch=3 lr=0.0007353719556704164 loss=0.015095868147909641\n",
      "epoch=4 lr=0.0007545390981249511 loss=0.014482150785624981\n",
      "epoch=5 lr=0.000773213803768158 loss=0.013594602234661579\n",
      "epoch=6 lr=0.0007913602748885751 loss=0.01279853843152523\n",
      "epoch=7 lr=0.0008089432376436889 loss=0.01211627572774887\n",
      "epoch=8 lr=0.0008259288733825088 loss=0.011687271296977997\n",
      "epoch=9 lr=0.0008422840619459748 loss=0.011313271708786488\n",
      "epoch=10 lr=0.0008579773711971939 loss=0.010813306085765362\n",
      "epoch=11 lr=0.000872978416737169 loss=0.010584156960248947\n",
      "epoch=12 lr=0.0008872582111507654 loss=0.009968992322683334\n",
      "epoch=13 lr=0.0009007891057990491 loss=0.009719228371977806\n",
      "epoch=14 lr=0.0009135449072346091 loss=0.009495005011558533\n",
      "epoch=15 lr=0.0009255009354092181 loss=0.009139549918472767\n",
      "epoch=16 lr=0.0009366340818814933 loss=0.008803507313132286\n",
      "epoch=17 lr=0.000946922751609236 loss=0.00873634498566389\n",
      "epoch=18 lr=0.0009563472121953964 loss=0.008653072640299797\n",
      "epoch=19 lr=0.0009648891282267869 loss=0.008492065593600273\n",
      "epoch=20 lr=0.0009725319687277079 loss=0.008285383693873882\n",
      "epoch=21 lr=0.0009792608907446265 loss=0.00816423911601305\n",
      "epoch=22 lr=0.000985063030384481 loss=0.007690694183111191\n",
      "epoch=23 lr=0.0009899272117763758 loss=0.007603948935866356\n",
      "epoch=24 lr=0.0009938437724485993 loss=0.007507197558879852\n",
      "epoch=25 lr=0.000996805145405233 loss=0.007343694102019072\n",
      "epoch=26 lr=0.0009988059755414724 loss=0.007134806364774704\n",
      "epoch=27 lr=0.0009998419554904103 loss=0.007159603759646416\n",
      "epoch=28 lr=0.0009999113390222192 loss=0.00695855962112546\n",
      "epoch=29 lr=0.0009990138933062553 loss=0.006894073914736509\n",
      "epoch=30 lr=0.000997151480987668 loss=0.006520343013107777\n",
      "epoch=31 lr=0.000994327594526112 loss=0.00651488546282053\n",
      "epoch=32 lr=0.0009905475890263915 loss=0.006448198109865189\n",
      "epoch=33 lr=0.000985819031484425 loss=0.006177311297506094\n",
      "epoch=34 lr=0.0009801508858799934 loss=0.006048914976418018\n",
      "epoch=35 lr=0.0009735542116686702 loss=0.006114535499364138\n",
      "epoch=36 lr=0.0009660416981205344 loss=0.006011721212416887\n",
      "epoch=37 lr=0.0009576278389431536 loss=0.00571891525760293\n",
      "epoch=38 lr=0.0009483289322815835 loss=0.005739100277423859\n",
      "epoch=39 lr=0.0009381631389260292 loss=0.0055201491340994835\n",
      "epoch=40 lr=0.0009271499002352357 loss=0.00546996621415019\n",
      "epoch=41 lr=0.0009153104037977755 loss=0.005301312543451786\n",
      "epoch=42 lr=0.000902667990885675 loss=0.005209599621593952\n",
      "epoch=43 lr=0.000889246875885874 loss=0.005011134315282106\n",
      "epoch=44 lr=0.0008750728447921574 loss=0.004916655831038952\n",
      "epoch=45 lr=0.0008601734298281372 loss=0.004685334395617247\n",
      "epoch=46 lr=0.00084457773482427 loss=0.004444870166480541\n",
      "epoch=47 lr=0.0008283156785182655 loss=0.0043732775375247\n",
      "epoch=48 lr=0.0008114183438010514 loss=0.0042518614791333675\n",
      "epoch=49 lr=0.0007939189672470093 loss=0.004336335230618715\n",
      "epoch=50 lr=0.0007758514257147908 loss=0.004135205876082182\n",
      "epoch=51 lr=0.0007572502945549786 loss=0.003962153568863869\n",
      "epoch=52 lr=0.000738151662517339 loss=0.003910616971552372\n",
      "epoch=53 lr=0.0007185924332588911 loss=0.003692931029945612\n",
      "epoch=54 lr=0.0006986107910051942 loss=0.0036367292050272226\n",
      "epoch=55 lr=0.0006782450363971293 loss=0.003404016373679042\n",
      "epoch=56 lr=0.0006575344013981521 loss=0.003396084066480398\n",
      "epoch=57 lr=0.0006365194567479193 loss=0.0032212536316365004\n",
      "epoch=58 lr=0.0006152403657324612 loss=0.0031083771027624607\n",
      "epoch=59 lr=0.0005937386304140091 loss=0.0029411588329821825\n",
      "epoch=60 lr=0.0005720554618164897 loss=0.002856674836948514\n",
      "epoch=61 lr=0.0005502330604940653 loss=0.0027525722980499268\n",
      "epoch=62 lr=0.0005283138016238809 loss=0.0026267138309776783\n",
      "epoch=63 lr=0.000506339652929455 loss=0.0024938301648944616\n",
      "epoch=64 lr=0.00048435357166454196 loss=0.0023678250145167112\n",
      "epoch=65 lr=0.0004623976710718125 loss=0.002306531183421612\n",
      "epoch=66 lr=0.00044051450095139444 loss=0.002116507850587368\n",
      "epoch=67 lr=0.0004187468148302287 loss=0.0020464735571295023\n",
      "epoch=68 lr=0.0003971360856667161 loss=0.0019454826833680272\n",
      "epoch=69 lr=0.00037572436849586666 loss=0.0017854681937023997\n",
      "epoch=70 lr=0.00035455342731438577 loss=0.001730295829474926\n",
      "epoch=71 lr=0.0003336636000312865 loss=0.0016554018948227167\n",
      "epoch=72 lr=0.00031309586483985186 loss=0.0015270781004801393\n",
      "epoch=73 lr=0.0002928894537035376 loss=0.0013912373688071966\n",
      "epoch=74 lr=0.0002730839769355953 loss=0.0013528444105759263\n",
      "epoch=75 lr=0.0002537174441386014 loss=0.001237965770997107\n",
      "epoch=76 lr=0.00023482716642320156 loss=0.0011553441872820258\n",
      "epoch=77 lr=0.00021644984371960163 loss=0.0010686635505408049\n",
      "epoch=78 lr=0.0001986212591873482 loss=0.0009855644311755896\n",
      "epoch=79 lr=0.00018137549341190606 loss=0.00090043805539608\n",
      "epoch=80 lr=0.0001647462631808594 loss=0.0008569382480345666\n",
      "epoch=81 lr=0.0001487653498770669 loss=0.0007737468695268035\n",
      "epoch=82 lr=0.00013346386549528688 loss=0.0007141795940697193\n",
      "epoch=83 lr=0.00011887160508194938 loss=0.0006533203995786607\n",
      "epoch=84 lr=0.00010501645738258958 loss=0.0005970757920295\n",
      "epoch=85 lr=9.192524885293096e-05 loss=0.0005417960346676409\n",
      "epoch=86 lr=7.96236636233516e-05 loss=0.0004993216134607792\n",
      "epoch=87 lr=6.813512300141156e-05 loss=0.0004574138729367405\n",
      "epoch=88 lr=5.748197509092279e-05 loss=0.0004249828343745321\n",
      "epoch=89 lr=4.768490543938242e-05 loss=0.00038963695988059044\n",
      "epoch=90 lr=3.876250411849469e-05 loss=0.00037066458025947213\n",
      "epoch=91 lr=3.073247353313491e-05 loss=0.0003445000620558858\n",
      "epoch=92 lr=2.3610193238710053e-05 loss=0.0003196481557097286\n",
      "epoch=93 lr=1.740936750138644e-05 loss=0.000304781278828159\n",
      "epoch=94 lr=1.2142033483542036e-05 loss=0.0002857021172530949\n",
      "epoch=95 lr=7.818327503628097e-06 loss=0.00027270265854895115\n",
      "epoch=96 lr=4.446651473699603e-06 loss=0.00026067483122460544\n",
      "epoch=97 lr=2.0336199213488726e-06 loss=0.0002509492333047092\n",
      "epoch=98 lr=5.838221568410518e-07 loss=0.00024522366584278643\n",
      "epoch=99 lr=1.0000000116860974e-07 loss=0.00024028531333897263\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train from the current global `epoch`: if a run is interrupted, re-running\n",
    "# this cell resumes at the epoch that was in progress (after a completed run\n",
    "# it would redo the final epoch).\n",
    "batch = 512  # minibatch size (loop-invariant, so set once)\n",
    "for epoch in tqdm(range(epoch, total_epochs)):\n",
    "    epoch_loss = 0.0\n",
    "    num_samples = 0\n",
    "    # Draw the whole epoch's data in one call, then slice it into minibatches.\n",
    "    ocstate, occonditionals, octarget = gen_data_batch(epoch, minibatch_per*batch)\n",
    "    for minibatch in range(minibatch_per):\n",
    "        # Fractional training progress in [0, 1); passed to update_derivative as\n",
    "        # the optimizer step and used below to report the OneCycleLR rate.\n",
    "        fraction = (epoch + minibatch/minibatch_per)/total_epochs\n",
    "        s = np.s_[minibatch*batch:(minibatch+1)*batch]\n",
    "\n",
    "        cstate, cconditionals, ctarget = ocstate[s], occonditionals[s], octarget[s]\n",
    "        # `params` are the pre-step parameters the gradient was taken at.\n",
    "        opt_state, params = update_derivative(fraction, opt_state, cstate, cconditionals, ctarget)\n",
    "        rng += 10  # NOTE(review): rng is not consumed in this loop; kept so later cells see the same value\n",
    "\n",
    "        cur_loss = loss(params, cstate, cconditionals, ctarget)\n",
    "        epoch_loss += cur_loss\n",
    "        num_samples += len(cstate)\n",
    "    # Average loss per sample, pulled to the host once as a Python float so\n",
    "    # `best_loss` stays a plain scalar like its initial `np.inf`.\n",
    "    closs = float(epoch_loss/num_samples)\n",
    "    print('epoch={} lr={} loss={}'.format(epoch, OneCycleLR(fraction), closs))\n",
    "    if closs < best_loss:\n",
    "        best_loss = closs\n",
    "        # Keep independent host-side copies of the best parameters seen so far.\n",
    "        best_params = [[copy(jax.device_get(l2)) for l2 in l1] if len(l1) > 0 else () for l1 in params]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to save the best parameters from the training loop above.\n",
    "# with open('best_sr_params_hamiltonian_true_canonical.pkl', 'wb') as params_file:\n",
    "#     pkl.dump({'params': best_params,\n",
    "#               'description': 'q and g are divided by 10. hidden=500. act=Softplus'},\n",
    "#              params_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
    "with open('best_sr_params_hamiltonian_true_canonical.pkl', 'rb') as params_file:\n",
    "    best_params = pkl.load(params_file)['params']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_state = opt_init(best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "cstate, cconditionals, ctarget = gen_data(0, 1, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 2)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cstate.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x2aab7e8c4ed0>]"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAeAElEQVR4nO3deXRc5Znn8e+jfV8syZYteQUDXgADGrOFhCUQQ9KQhUwgdBISgtMnIUtP0n0gJ02nSeck3WdOkzkZMjME0qEzaQhNEuJm6GYJ0Ek6BixjMFhe8IJtLZZkLaWtqrQ980eVTSFkq2xLLtet3+ecOlX31nXpuXbp5/e8933fa+6OiIikv6xUFyAiItNDgS4iEhAKdBGRgFCgi4gEhAJdRCQgclL1g6urq33RokWp+vEiImlp48aNB929ZrL3UhboixYtorGxMVU/XkQkLZnZ3iO9py4XEZGAUKCLiASEAl1EJCAU6CIiAaFAFxEJCAW6iEhAKNBFRAIiZePQRUSCxN2Jjo4zEB1lIDLKQHSUwWjsOfZ67PD2lWfN5tz5FdNegwJdRDJedHSM/sho/DHCQGSUvkOvo7H9ic8DkZF37BscjoX46Hhy95eoKc1XoIuITOTuREbG6YuM0BceoS8yQig8Ql94NGFfLJwT9/XHQ7svMsLw6PiUPyc/J4vSglxKC3IoyY895s8qojQ/h5L4vuL8nMPvF+fnUBp/Ls4/tC+b4rwcsrJsRv4uFOgickoYGRund2iEUHg4/jzy9nM4FsK9Q8OEwiMJj1H6wiMMjx09kAtysygryKWsMBbI5UV5sTAuyKWsIBbCh94rzc+lJL6vND+2rzg/h7ycU/+SowJdRKaVuzM4PEbP4DA9Q8N0D8YCOvY8TM/QCD1DsX294WF6BmPhPBAdPeJnmkFpfg7lRbmUF8YeteUFlBfmUV6YS1lhTuy5IDe+HQvqQyGdn5N9Ev8GUkeBLiJHNT7u9IZH6B6McnBgmK6BYboHo3QNDtMzOEzXYCy0u+MB3jN45BazGZQX5lJZlEdFUS6zSws4Y3YpFfHtinhgVxTlUVH49nZpQS7ZM9RNESQKdJEMNDw6zsGBKJ39UboGoxzsH6ZzIErXwDAHB6IcHIjSPTjMwYFYSI8d4WJfWUEOVSX5zCrOo76yiHPrK6gszmNWcSy0ZxXnUVGUR2VRbLusUME8kxToIgHh7vSFR+noj9DRH6WjP0Jnf5SOvigd/dHDAd45EKV3aGTSzyjJz6G6JI+qknwWzCrivAWVse3iPGaV5FNVnEdVSSyoK4vyyM0+9fuVM4kCXSQNREfHaA9FaQuFOdAXob0vQntflPa+CB190cP7opOM1ijMzaamNJ/ZpfmcPruEi0+roqYkn+rSfKpL8qkpjQV1TWk+BbmZ0dccVAp0kRQbHh3nQChCayhMWyhMa2+EtlCYtt4IbaFYUHcNDr/rzxXmZlNbXsDs0nxWza84/DoW3gXMLouFeEl+Dmbq5sgECnSRGTYYHaW5J0xzzxAtvWFaesI094Zpjb/uHIjiE7qoywtzmVtewLyKQlYtqKC2rIDa8gLmlhdQW1bAnPICShXUMoECXeQEjYyN09ITZl/3EHu7h2juHmJ/zxDNPWH2dw/RM6G/Oi87i3kVBdRVFvK+M2qYV1FIXUUhcysKmFteyLyKAory9Kspxy6pb42ZrQH+B5ANPODu35/w/kLgJ0AN0A38qbs3T3OtIikTGRljX/cQew4OsrdrkD0Hh9jXPci+7iFaeyPvGAWSl51FfWUhdZWFrDx7LvWVhcyvLDq8r7o4f8ZmCkpmmzLQzSwbuA+4GmgGNpjZOndvSjjsvwP/5O4PmdmVwPeAT81EwSIzZWzcaekJs+vgALs7B9ndGXve2zVIayjyjmMri3JZWFXM+Qsq+fCqIhbMij0WVhUzu1SBLamRTAt9NbDT3XcDmNkjwA1A
YqAvB/48/vp54PHpLFJkOkVGxtjdOcibHf3s6hjgzY4BdnUO8NbBoXdMiCkvzGVxdTEXLaliYVUxi6qLWFxdzMKqYsoLc1N4BiKTSybQ64D9CdvNwIUTjnkN+BixbpmPAKVmVuXuXYkHmdlaYC3AggULjrdmkaSMjI2z5+Ag2w70s+NAP9sO9PNmRz/7uocOX4TMMlhYVcxpNSVccdZsllQXs6SmhCXVxcwqztNFR0kryQT6ZN/oidPGvgH8TzO7Ffgd0AK8a2EGd78fuB+goaEhuXUmRZJwcCBKU2sfW9v6aGrrY/uBfnZ1DjAyFvuaZWcZS6qLWTmvnA+vqmPpnBJOn13C4urijFnnQ4IvmUBvBuYnbNcDrYkHuHsr8FEAMysBPubuoekqUuQQd6e5J8zrLSFebwkdDvGO/ujhY+aVF3DW3DKuOGs2Z84p5czaUpbUKLgl+JIJ9A3AUjNbTKzlfRPwycQDzKwa6Hb3ceAuYiNeRE5YWyjMq/t62dwS4o14iB+atp6TZSydU8p7llazfG4Zy+eVsay2jMrivBRXLZIaUwa6u4+a2R3AU8SGLf7E3beY2T1Ao7uvAy4HvmdmTqzL5UszWLME1NDwKJubQ7y6v5dN+3p4dX8v7X2xlndOlnFmbSlrVtSysq6cs+vKObO2VFPVRRKYT5yidpI0NDR4Y2NjSn62nBo6+iI07u3h5T3dNO7tZmtb/+Hx3AurijhvfgWr5lewakElZym8RQAws43u3jDZe5qOJifN/u4h1u/q4qV4gO/tGgJid5M5b34lX7z8NM5bUMG59RVUleSnuFqR9KNAlxnT2htm/a4u1u/uYv2uLlp6wwBUFefRsKiST120kIZFs1gxr0zLsIpMAwW6TJuh4VFe3N3F73Yc5D92dLLn4CAAFUW5XLS4irXvXcIlp1Vx+uwSje8WmQEKdDlu7s729n7+Y3snv3uzkw17ehgeG6cgN4uLl1Rxy4ULuOS0as6qLdVUeJGTQIEux2RkbJwNe7p5uqmdZ7e209wT60Y5c04pt166iPcuraFhUaUuYIqkgAJdpjQYHeW5bR08u7Wd57d10BcZJT8ni/ecXs2XrjidK86cTW15QarLFMl4CnSZVHh4jOe2dfDE5lae29ZBdHScquI8PrCilquXz+E9S6u1ZrfIKUa/kXJYZGSMF7Z38sTmVn67tYPwyBjVJfnc9F/mc93Zc2lYNEt3bBc5hSnQM5y7s7k5xGMbm1n3Wiuh8AhVxXl89Pw6PnTOPFYvVoiLpAsFeobq6I/w+KYWHtvYzI72AfJzslizspaPnV/PJadVkaNx4SJpR4GeQdyd/9zZxUPr3+K5bR2MjTsXLKzkex89mw+eM5eyAt20QSSdKdAzQH9khF9ubOafXtzL7s5BqorzuP2yJXy8oZ7TakpSXZ6ITBMFeoDt7OjnoT/u5VevNDM4PMaq+RXc+4lzue7suVobXCSAFOgB9Nr+Xu57fidPN7WTl5PF9efO49MXL+Sc+opUlyYiM0iBHhDuzvrdXfzo+V38YedBygtz+epVS/n0xQu1cqFIhlCgpzl357dbO7jvhZ1s2tdLTWk+37zuLD554UJK8vXPK5JJ9Bufxja81c33ntzKK/t6qa8s5G8/vJIbL6jXOioiGUqBnoZ2dvTz/X/bzrNb25lTls/3P3o2N15Qr7HjIhlOgZ5G2vsi3PvMDh5t3E9xXg5/8YEz+dyliynMU4tcRBToaWF4dJwf/343P3zuTcbGnVsvWcwdV57OLN3dXkQSKNBPcRve6uabv3qdNzsGuHZlLXddu4wFVUWpLktETkEK9FNUaGiE7//7Nh5+eR91FYU8+JkGrlo2J9VlicgpTIF+inF31r3WyneeaKJnaIS1713CV69aSrGGIIrIFJQSp5CewWH+8pebeaapnXPry3noc6tZMa881WWJSJpQoJ8iXtzdxdceeZWuwSjf+uAyPnvpYq1DLiLHRIGeYqNj4/zwuZ38
8Lk3WVhVzK8/cykr69QqF5Fjp0BPodbeMF975FVefqubj55fxz03rNR0fRE5bkqPFHm2qZ2v/8trjI6Nc+8nzuUj59WnuiQRSXMK9JPM3XnwD3v47pNbWTGvjB/efD6Lq4tTXZaIBEBSi3+Y2Roz225mO83szkneX2Bmz5vZJjPbbGbXTX+p6W9s3Pmbf23ib//fVj6wvJbH/uwShbmITJspW+hmlg3cB1wNNAMbzGyduzclHPYt4FF3/19mthx4Elg0A/WmraHhUb7y8Cae3drB7Zct5q5rl5GlUSwiMo2S6XJZDex0990AZvYIcAOQGOgOlMVflwOt01lkuuvoj3DbTxvZ0hrinhtW8OmLF6W6JBEJoGQCvQ7Yn7DdDFw44ZhvA0+b2ZeBYuD9k32Qma0F1gIsWLDgWGtNSzva+/nsP26ge3CYH39a0/dFZOYk04c+Wb+AT9i+Gfipu9cD1wE/M7N3fba73+/uDe7eUFNTc+zVppmm1j4+/r/XMzw2zqNfuFhhLiIzKpkWejMwP2G7nnd3qdwGrAFw9/VmVgBUAx3TUWQ62tU5wKcefImivGwe/cLFzJ+lFRJFZGYl00LfACw1s8VmlgfcBKybcMw+4CoAM1sGFACd01loOtnfPcQtP34JM/j55y9UmIvISTFloLv7KHAH8BSwldholi1mdo+ZXR8/7OvA7Wb2GvAwcKu7T+yWyQgHQhFueeAlwiNj/Oy2C1lSU5LqkkQkQyQ1scjdnyQ2FDFx390Jr5uAS6e3tPTTNRDllgdepGsgys9vv4hlc8um/kMiItNEM0WnSSg8wqcefJmW3jAPfXY1q+ZXpLokEckwuk38NAgPj/HZf3yZNzv6+T+fauDCJVWpLklEMpBa6CfI3fnW42+waX8vP/rk+bzvjOAPxxSRU5Na6Cfo0cb9/PKVZr585VKuPXtuqssRkQymQD8Bb7SE+KvfbOGypdV89aqlqS5HRDKcAv04hcIjfPHnrzCrKI8ffGKVbhcnIimnPvTj4O58419eo7U3zC++cBFVJfmpLklERC304/Hj3+/mmaZ27rpuGRcsnJXqckREAAX6MXtpdxd/9+/buXZlLZ+7dFGqyxEROUyBfgw6+6N8+eFNLJhVxN/feA5m6jcXkVOHAv0YfOeJJnrDI/zolvMpLchNdTkiIu+gQE/S+l1drHutlT9732lao0VETkkK9CSMjI3z1+veoL6ykC9eflqqyxERmZQCPQkP/fEtdrQPcPeHllOQm53qckREJqVAn0JHX4QfPPsml59Zw9XLdQs5ETl1KdCn8L1/28bw6Djf/pMVGtUiIqc0BfpRvLynm19vamHte5ewqLo41eWIiByVAv0IRsfGufs3b1BXUciXrjg91eWIiExJgX4EP3txL9sO9PNXH1pGYZ4uhIrIqU+BPonO/ij/8PQO3ntGDR9YUZvqckREkqJAn8S9z+4gMjrGt/9kuS6EikjaUKBPcHAgymMbm7nxgvksqSlJdTkiIklToE/wf1/cy/DoOLe9Z3GqSxEROSYK9ASRkTF+tn4vV501m9Nnq3UuIulFgZ7g8U0tdA0Oc9tlap2LSPpRoMe5Ow/8YQ8r5pVx8ZKqVJcjInLMFOhxL+zoZGfHAJ+/bLFGtohIWlKgxz34+z3MKcvng2fPS3UpIiLHRYEObG3r4w87D3LrJYvJy9FfiYikp6TSy8zWmNl2M9tpZndO8v69ZvZq/LHDzHqnv9SZ88Dv91CYm80nVy9IdSkiIsctZ6oDzCwbuA+4GmgGNpjZOndvOnSMu/95wvFfBs6bgVpnREdfhHWvtfDJ1QsoL9J9QkUkfSXTQl8N7HT33e4+DDwC3HCU428GHp6O4k6Gh9a/xei48zlNJBKRNJdMoNcB+xO2m+P73sXMFgKLgeeO8P5aM2s0s8bOzs5jrXXaDQ2P8vOX9nHN8jksrNJ65yKS3pIJ9MnG8PkRjr0JeMzdxyZ7093vd/cGd2+oqalJtsYZ
88tXWugdGuH2y5akuhQRkROWTKA3A/MTtuuB1iMcexNp1N3ys/VvcW59ORcsrEx1KSIiJyyZQN8ALDWzxWaWRyy01008yMzOBCqB9dNb4szY1TnAjvYBPnJenSYSiUggTBno7j4K3AE8BWwFHnX3LWZ2j5ldn3DozcAj7n6k7phTytNb2gG4RjewEJGAmHLYIoC7Pwk8OWHf3RO2vz19Zc28p5sOcHZdOfMqClNdiojItMjIaZHtfRE27evlAyvmpLoUEZFpk5GB/kyTultEJHgyMtCf2nKAxdXFLNVNLEQkQDIu0EPhEdbv6uKa5XM0ukVEAiXjAv2F7R2MjjvXqP9cRAIm4wL96S3tVJfkc958TSYSkWDJqECPjIzxwvYOrl4+h6wsdbeISLBkVKD/cddBBofHNFxRRAIpowL96S3tlOTncPFpugm0iARPxgT62LjzTFM7V5w1m/yc7FSXIyIy7TIm0F/Z10PX4DDXLFd3i4gEU8YE+lNvHCAvO4vLz0z9OuwiIjMhIwLd3Xm6qZ1LTq+itED3DRWRYMqIQN92oJ993UNcs1xrt4hIcGVEoD+9pR0zeP/y2akuRURkxmREoD+15QDnL6hkdmlBqksREZkxgQ/0gwNRmtr6uGqZWuciEmyBD/TXW0IAWrtFRAIv+IHeHAv0lXVlKa5ERGRmBT/QW0IsqSnWcEURCbzgB3pziHPqylNdhojIjAt0oHf0RTjQF2GlAl1EMkCgA/3QBdFz6itSXImIyMwLdKBvbg5hBivm6YKoiARfoAP9jZYQp9eUUJyfk+pSRERmXGAD3d3Z3BLi7Hr1n4tIZghsoLf3Rensj3K2LoiKSIYIbKBvbu4F4By10EUkQwQ20N9oCZFlsHyuAl1EMkNSgW5ma8xsu5ntNLM7j3DMfzWzJjPbYmb/PL1lHrvNLSHOmFNKYZ7uHyoimWHK4R9mlg3cB1wNNAMbzGyduzclHLMUuAu41N17zCylSxu6O683h7jyLK2wKCKZI5kW+mpgp7vvdvdh4BHghgnH3A7c5+49AO7eMb1lHpvWUISuwWGNcBGRjJJMoNcB+xO2m+P7Ep0BnGFm/2lmL5rZmsk+yMzWmlmjmTV2dnYeX8VJOLTCoka4iEgmSSbQbZJ9PmE7B1gKXA7cDDxgZu+ab+/u97t7g7s31NTUHGutSXu9pZecLGPZXM0QFZHMkUygNwPzE7brgdZJjvmNu4+4+x5gO7GAT4nNzbELogW5uiAqIpkjmUDfACw1s8VmlgfcBKybcMzjwBUAZlZNrAtm93QWmix35/WWkLpbRCTjTBno7j4K3AE8BWwFHnX3LWZ2j5ldHz/sKaDLzJqA54G/cPeumSr6aJp7wvQOjeiCqIhknKRWrXL3J4EnJ+y7O+G1A/8t/kipt5fMVaCLSGYJ3EzRzc0hcrONM2tLU12KiMhJFbhAf72ll7Nqy8jP0QVREcksgQr0QzNEdcs5EclEgQr0fd1D9EVG1X8uIhkpUIG+WTNERSSDBSrQX28JkZeTxRlzdEFURDJPoAJ9c3Mvy2pLycsJ1GmJiCQlMMk3Pu5saenThCIRyViBCfS3ugbpj45yTt271gQTEckIgQn0XZ2DACydU5LiSkREUiMwgd4WCgNQV1GY4kpERFIjMIHe2hshN9uoLslPdSkiIikRmEBvC4WZU1ZAVtZk9+MQEQm+4AR6b4R55epuEZHMFZhAbw2FmVtRkOoyRERSJhCBPj7utPdFmKsWuohksEAE+sGBKCNjzjy10EUkgwUi0FtDEQC10EUkowUi0Nt6Y2PQ55arhS4imSsQgX6ohT5Pk4pEJIMFItDbesPk52RRWZSb6lJERFImGIEeijCvohAzTSoSkcwViEBvDYXVfy4iGS8Qgd7WqzHoIiJpH+ijY+N09Ec0Bl1EMl7aB3p7f5Rx1xh0EZG0D/QD8XXQtY6LiGS6tA/01t74GHS10EUkw6V9oLephS4iAiQZ6Ga2xsy2m9lO
M7tzkvdvNbNOM3s1/vj89Jc6udbeCCX5OZQVaFKRiGS2nKkOMLNs4D7gaqAZ2GBm69y9acKhv3D3O2agxqNq0xh0EREguRb6amCnu+9292HgEeCGmS0reW2hCHO1houISFKBXgfsT9huju+b6GNmttnMHjOz+ZN9kJmtNbNGM2vs7Ow8jnLfrbU3wjy10EVEkgr0yRZI8Qnb/woscvdzgGeBhyb7IHe/390b3L2hpqbm2CqdRHR0jIMDUY1BFxEhuUBvBhJb3PVAa+IB7t7l7tH45o+BC6anvKNrD8V+pEa4iIgkF+gbgKVmttjM8oCbgHWJB5jZ3ITN64Gt01fikbXGhyxqDLqISBKjXNx91MzuAJ4CsoGfuPsWM7sHaHT3dcBXzOx6YBToBm6dwZoP0xh0EZG3TRnoAO7+JPDkhH13J7y+C7hrekubmmaJioi8La1niraFwlQU5VKYl53qUkREUi69A13roIuIHJbWgd4a0hh0EZFD0jrQ20JhXRAVEYlL20APD4/ROzSiLhcRkbi0DfTDY9DVQhcRAdI40NviQxbVQhcRiUnbQNcsURGRd0rbQD/UQp9Tnp/iSkRETg3pG+ihMNUl+eTnaFKRiAikcaC3hiK6ICoikiBtA72tV7eeExFJlL6BHtK0fxGRRGkZ6H2REQaio+pyERFJkJaBrjHoIiLvlpaBrlmiIiLvlpaBrha6iMi7pWegh8JkGcwu1aQiEZFD0jLQW3sjzCkrICc7LcsXEZkRaZmIbSGNQRcRmShNAz3C3Ar1n4uIJEq7QHd3WnvDuvWciMgEaRfoPUMjREfHNcJFRGSCtAv01l6NQRcRmUzaBXpbSGPQRUQmk4aBHmuhz1ULXUTkHdIu0GvLCrh6+RyqizWpSEQkUU6qCzhW16yo5ZoVtakuQ0TklJN2LXQREZlcUoFuZmvMbLuZ7TSzO49y3I1m5mbWMH0liohIMqYMdDPLBu4DrgWWAzeb2fJJjisFvgK8NN1FiojI1JJpoa8Gdrr7bncfBh4BbpjkuO8Afw9EprE+ERFJUjKBXgfsT9huju87zMzOA+a7+xNH+yAzW2tmjWbW2NnZeczFiojIkSUT6DbJPj/8plkWcC/w9ak+yN3vd/cGd2+oqalJvkoREZlSMoHeDMxP2K4HWhO2S4GVwAtm9hZwEbBOF0ZFRE6uZAJ9A7DUzBabWR5wE7Du0JvuHnL3andf5O6LgBeB6929cUYqFhGRSU05scjdR83sDuApIBv4ibtvMbN7gEZ3X3f0T5jcxo0bD5rZ3uP5s0A1cPA4/2w6y9Tzhsw9d513ZknmvBce6Q1z9yO9d8oys0Z3z7gunUw9b8jcc9d5Z5YTPW/NFBURCQgFuohIQKRroN+f6gJSJFPPGzL33HXemeWEzjst+9BFROTd0rWFLiIiEyjQRUQCIu0CPdmlfNOdmf3EzDrM7I2EfbPM7BkzezP+XJnKGmeCmc03s+fNbKuZbTGzr8b3B/rczazAzF42s9fi5/038f2Lzeyl+Hn/Ij65L3DMLNvMNpnZE/HtwJ+3mb1lZq+b2atm1hjfd0Lf87QK9GSX8g2InwJrJuy7E/ituy8FfhvfDppR4OvuvozYMhJfiv8bB/3co8CV7n4usApYY2YXAX8H3Bs/7x7gthTWOJO+CmxN2M6U877C3VcljD0/oe95WgU6yS/lm/bc/XdA94TdNwAPxV8/BHz4pBZ1Erh7m7u/En/dT+yXvI6An7vHDMQ3c+MPB64EHovvD9x5A5hZPfBB4IH4tpEB530EJ/Q9T7dAn3Ip34Cb4+5tEAs+YHaK65lRZrYIOI/YTVMCf+7xbodXgQ7gGWAX0Ovuo/FDgvp9/wHwl8B4fLuKzDhvB542s41mtja+74S+5+l2k+ijLuUrwWFmJcAvga+5e1+s0RZs7j4GrDKzCuDXwLLJDju5Vc0sM/sQ0OHuG83s8kO7Jzk0UOcdd6m7t5rZbOAZM9t2oh+Ybi30qZby
Dbp2M5sLEH/uSHE9M8LMcomF+c/d/Vfx3Rlx7gDu3gu8QOwaQoWZHWp4BfH7filwfXzp7UeIdbX8gOCfN+7eGn/uIPYf+GpO8HueboF+1KV8M8A64DPx158BfpPCWmZEvP/0QWCru/9DwluBPnczq4m3zDGzQuD9xK4fPA/cGD8scOft7ne5e3186e2bgOfc/RYCft5mVhy/DzNmVgxcA7zBCX7P026mqJldR+x/8ENL+X43xSXNCDN7GLic2HKa7cBfA48DjwILgH3Ax9194oXTtGZm7wF+D7zO232q3yTWjx7Yczezc4hdBMsm1tB61N3vMbMlxFqus4BNwJ+6ezR1lc6ceJfLN9z9Q0E/7/j5/Tq+mQP8s7t/18yqOIHvedoFuoiITC7dulxEROQIFOgiIgGhQBcRCQgFuohIQCjQRUQCQoEuIhIQCnQRkYD4/4uPZ4TkF57aAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(cstate[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_params(opt_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rc('font', family='serif')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "689f361a419f4445a4b4ea3166afcef9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARkAAAEYCAYAAABoTIKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2deXhU1fnHP29CIIBhDXsgARL2HWSpFrG4dbFWra1Va3GtrUv9Wdy6WK224lLX1oW60KK1m1ZcimKLC1aQHSJ7gAAhBBKWsISQ7f39cU9gjMlkkszkTibv53nmmXvPPfee770z884571leUVUMwzAiRZzfAgzDiG3MyBiGEVHMyBiGEVHMyBiGEVHMyBiGEVHMyBiGEVGahZERkTNFZKWIqIh8KCJdXPotIpItIoUi8tcIlPtvEZnitl8TkeLK/QiUdYmIPBeB6/5IRNaLyAER6RWQPl5EFolInog8KyKtReQDd48PVbnGUy7fIhGZVE8dF4rIQlfGxyLyNxEZ19D7q6XMkJ6piHzTPaMPajhe9fv3kYhsEpHZItI2hOtf676ns+p+F8evsVpE0kPIF/7vkao2ixcwBVCgRZX0u4GPI1RmO0AC9rOBKW47zXv8YSsrHjgpQvcxDSgD3qySngbMqpKW5fKOr5I+C0irZ/npwB6gl9uPA54Hpkf4OxPyM3XP6INQv39AR2A3cHeI17+76rMOkndW1esCHepzz9Vdq66vZlGT8QtVPajuk2qEsspV9XAEi3gcOEtELq0lXw7wD+AFEWkZprJHA9mquhNAVSuAR4H1Ybp+tUTymarqfmABENHaWEB5B0LMF/Z7NiNTBREZKiJzXbV2oYhcG3Cssslzu4i8ISJbROQiEblcROaLSKaIZLi8t7omwt3VlNEe+Kvb/sC94kUkQUQeEpFP3OshEUmopuzXRSRLRH7sjg131fHsgDKuddeYLyL/FZEhLn18ZV6n8UNXlR5Qy6PJBH4NPC4iXWvJewOQDPyylnyhsg0YLSLnVSao6meq+hZ87tn8WkTecZ/DoyISX5nf3esiEVkgIk8GGkARudQdm+9eU+vyTBtACzyjXHn9/iIyz30mC0TkSzWdKCK/dN/P+SLyloj0dOk/Ac4Bprnv1VUi8jvxmrvTXNNuj4hsFZHvunPed9/VGwPvuYZrbRKR/SJyn8vzOxHZKyK31niXkaxuRtOLE9XVD4EPAl7ZBDSXgAnABLedAKwDMgKOZwNPuO1v4VXjz3f7TwDP1FTVpJbmEt6P8j94VdZ44F3gl1XOf8ptjwcOc6L6PQXv374y7w+BVgHHFlR5FiXAqW7/KeDZIM9umnvFA0uAfwTcw6wqeT9w7+e5MkYFPIu0Bnx+TwAVwFrgLqBnlePZeIZbgERgFXCtO3ap+xzbuON/B37hjn0JyAO6uP1vV95THZ/pNOrWXOoDzOFEEzDeabzS7Y8ACoAkt3934LMGbsQ1xV3Zs2v63lV+LsA0t30TMC/g2HnANTXc8+euBYwBCoHWbr8r8Hywz6451mSmquqUyhfeQwxkE3CViPwPeA/ogVddD+Q99/4Z0AXPMACsBvo1QNvlwJ/Vq7KWA38GrqiS552AstrifcjVsRZ4U0QWADOAsVWOH1bVjwOu1bc2cU7TD4BzReTCWvLOAf6G12xqUdu1Qyj7JmAQ8DpwFbBJRKZWyfY39SgG/glc7NKnAX9V1SL1fhmvAN93x64A/q2q+W7/deDpGmTU9kxD4b8isgqvqfeeuiYgMBHoD8x297sa2Al8o4br7ADeF5GPgJvrqOUvwGQ54ci/CM/w1oqqLge24xkm8Ax40E6T5mhkauMRvB/uZGeEVuL9AwZyyL2XAahq4H5D/BApQH7Afr5LC+SgK7PY7X+hPNccewuYqapfxvuxta7uOo7i6q5THaq6FvgV8AegUy3ZbwK6AbcHyyQijwU0G0cFKXuj
qv4Mz5C/7HQEsj9gey/eHwR4z/CSyjKcnoqAY8efuaqWqeqn1WgM5ZmGwlS8P63fAw8FND1T8Go67wXobAW0r0ZLBp5RuFVVJ+MZmarf0RpR1QJgHnCZiHQEylW1sA73MBvvD7Hyfv4bLHOD/2FikPF4TZJyt5/QiGXvwKsZVdKFgDZ7HRiI17NVWesJ9z08jNdUfBTYWlMmVd0vnk/rVTxjXVO+m4MVJiITgMGqOsvlLxeR14H7q2QNNHrJwC63vQOv1nC8a11EkgOOdQlIbwEMVdVVVa4dtmeqqhXOVzcNrwl2r9NR6v7YKrW05YQxDGQ0cFBVlzRAy5/xfGyFhFiLCeAl4F4ROQNYp54jvkasJvNFsvD8MohID7y2cSQ45MpoIyJ3iMhEvKbbZc4JHAdcBrxYj2tvw6tVTXD754RB73GcAZ6GZ5Bry/s2XvNkQm15g9AauNb96HDP5nzgoyr5LhSPRLwmQGU1fhZwkUtHvLFKzwYc+1qA0fku3r1VJazPVFWL8HrsfiSec/9TYLuIXOA0tsBrulXnkM8COgY466tqOQS0EZG2IvJyDRLeBLoD1+L5/mriC9dS1Vw83+afcc27oNTXEdeUXsCZeP+klY7fSiffLXgOw0K8Njt47f6lwELgBTx/xXrgK+6hFrtrDQQWuWvOwfvBrQcOAA8Ct+I5FLPxfAivBZw71pX1MrAMz8fTGu8f6QHgE/d6GEhweQPL7uuup07DaS69mBNO2etc2W/h1TgUr4o8JCDvs1V1V/PsfuSOr8c5UgOO3cIJJ2lrPOfiAffeMiBfe7x/6rR6fn5dgMfcvb4PLAZm4pyiLk82cAcwF89X9hgQH3D8p3g/5Pl4P95uAccuc5/3B3i+nHbA8Do8028GPMMnQ/j+DQl4LoV437dT8Xwy77g8CzjhBL7WlZsH/Nyl3evS5rjPsRjPnwcwyelZDFwC/M5pWw98PUDX08CjAfvV3fPnrlXlma0O5fOr9E4bRpPGdbtOU9UPfJbSLBCRrwLDApugNWHNJcMwQkZEKh2+l+HVxGvFjIzR5BGR1/D8C4+JSH26lY3Q+YaIrAA2O99MrfjWXBKR7sB9wEhVPbma43HAb/EGnKXiDfhZ5I6dAVyANxBOVfWeRhNuGEad8LML+1Q8p1VN4yK+A7RT1TtEpBOwSEQG440deAavm/GYiLwqIlNVNWhfvWEY/uCbkVHVf0rwZQ++jue5R1X3iUgxMBSvp2Gbqh5z+f7n8n7ByLgxGtcCtG3bduygQYPCdwOGYQCwbNmyAlXtUtPxaB6M15UTI2vBG6HaFc/IVJf+BVR1Jl5XJ+PGjdOlS5dGRqlhNGNEZFuw49FsZPYASQH77Vya1pBuGEYUElW9S25UYWW16228gUA4n0wisAZv0FSqiLRy+U5xeQ3DiEJ8q8mIyGl4M2F7iMgv8EYlTsMbdXgd3nyK0SLyK7xp8ZerN5y9SER+BDwhIvl4ow7N6WsYUUqzGfFbnU+mtLSUnJwciouLazjLqEpiYiIpKSkkJDTmvFEjmhGRZapa4wp/0eyTiTg5OTkkJSWRlpaGiPgtJ+pRVfbu3UtOTg59+9a6/IxhAFHmk2lsiouL6dy5sxmYEBEROnfubDU/o040ayMDmIGpI/a8jLrS7I2MYRiRxYyMj3z5y19m+vTpXHHFFbRv357p06czffp0pk2bVqfrzJo1iwMHvIgXH3/8MWPGjOGDDz4Iv2DDqAfN2vHrN1deeSVXXHEFn332Ge+//z4PP/wwAC++WLfF8GbNmsWUKVPo0KEDp556KiNGRGoxP8OoO2ZkHPe8uYa1uQdrz1gHhvRsx6/OHVrj8SuuqBqIwGP9+vX06dOHa665hk8//ZSUlJTjztZZs2Yxc+ZMfvvb35Kdnc28efPIzs7mscceY9CgQVx33XUAzJs3j1dffZXly5fz8ssvk5aWFtZ7M4xQseZSFPLAAw+wZ88ebrjhBubMmcN11133uSbUtdce
jzfHWWedRVpaGjfffPNxAwPQt29fnnzySc4//3xeffXVxpRvGJ/DajKOYDUOP+jWrRsdO3YEYNSoUXX2saSne7HVk5OTyc7ODrM6I5YpOlzIvrztHMrPIW3EqbRum1T7SUEwIxOlVO0qTkpK4uBBrzm3ffv2zx2Lj49HVcnMzGTIkCHVnm8YWlHBgb272Zu7hcN7tnFs33a0cCfxR/JoXbyHdqX5dCzfR5IcPR7EKavN26SPPLVB5ZqR8ZmjR48yc+ZMCgsLeeGFF7jyyit57rnnKCws5JFHHuGWW24BvNpMRUUF9957L6mpqRQWFvLKK6/wve99j3POOYcZM2ZQXFzM9ddfz+rVq5k9ezYDBw7kzTffZP/+/WRlZR2v3RixSaURyd+xkUO7NlFSsIW4gzm0PrKT9iV5dC3fQ0cpoWPAOaUaz17pxIEWyRS06U9u6y+hSd2Jb9eD1p16kZbW0HDfzXzu0rp16xg8eLBPipou9tz8pXBfPru2rObgjrWUF2ymZeFW2h/dQfeyXE6So5/Lu592FMR35VBiD0pO6gXtU2jZqTcnde1Lp+5pdOqWQlx8fIP02Nwlw2ii7M/fRe6mFRzesRr2rCPp0Ga6lWynM4XHY9eWaRx5cV3Z16o3azqOQjv1JbFLf9r37E/XPgPpmNThczUXPzAjYxg+oxUV5GZvIG/9Ikp2rqTNvrX0OJpFV/YdNxAHaUNuQhqbO57Kpk7pJPYYTOfUIXRPHURKy1ZfCJgeTZiRMYxGpnDvbrJXvU/R5kW03buaPsUb6MVheuH5SHLie7Ot/Ti2dBlKm5ThdM8YQ5ceqQyKa5ojTszIGEaE2Z2zme3L3kWzP6Zb4SpSK3IYidfU2dYilY0dp6A9R9MpYwJ9Bo2lb2IbYmkhDTMyhhFmDhXuI2vRW5RueI+e+5eQorvoBhykLVtbD2NXt/NIGnAqfUecQv+T2tPfb8ERxoyMYYSBHVmZ7Pzk77TLeZ+MY2sZLeUc1tZktRlJTsr3SB52Bv2GTWRkA3tymiJmZHxm8eLF3HbbbZSUlPDII48wceJEvyUB3mzum266iUceeYQpU6b4LScq2bZuGbkL/0b3nHfpW5FNb2BLXBpLe15Ku+FfZcC4qYxq2arW68Q6fi4kHjTUrIg8D5+rSY4AxqhqtohkA9kufaeqXhp5xZFh/PjxTJkyhcOHD0eNgQFsNncN7Nuzk43/eZHkza+RXr6Z3ipsaDmERenTSTv1Yvr1yaCf3yKjDF+MjIi0ofZQs/NU9W8ufztglqpmu2OzVPXusIqaewfkZYb1knQfDl+dUa9Td+7cyZ133smwYcPIysrihz/8IWPHjuX1119nzpw5DBw4kMzMTJ5++mnatWvHd7/7XbZs2cKZZ57JggULuOCCC3j33XeJj49nxIgRLFq0iEsuuYRrrrkGgLvuuouysjLi4+NJSkritttuA+Cmm26itLSUfv36kZOTE7ZH0ZTRigo++3gOZYtmMuzIp0yUcjbFZ7Bo4O2kT7mMwT36+C0xqqnVyIhIS1UtcdvJQIKq7mpguZOoJdRspYFxXAW8ELA/WURuwwvyNldVP2mgnqhj+vTpnHvuuVxyySVkZ2dz/vnns2LFCjp27Mhjjz1G+/bteeSRR5g9ezbXX389DzzwAKeccgr33HMPxcXF7Nq1i9GjR3PnnXdy//33k5+fz1e+8hWuueYa3n33XRYtWsS8efMAmDJlCmeddRY7d+5k06ZNzJ07F4A5c+b4+Qh8p+hwIZn/nkn39X9ieMUO9tKepT0upvvkK8kYMo4MvwU2EUKpydwB/NpttwQeBC5rYLk1haD9AiISB5wNPBaoSVUXuxrRchH5hqpmVXPu8VjYffrU8m9TzxpHpFi9ejVdu3Zl+/btqCpdu3aloqKCk046iV//+tckJyezfPlyhg49MXs8PT2dhIQEEhISSEpKIjc3lwEDBgDQpUsXDh06dPzaRUVFzJjh3XPv3r3Jz89nzZo1
ZGSc+On069c8K/6HD+4n89UZDNk2mwkcYVN8OktG3c+Ic6YxKbFN7RcwPkeNRkZERgCjgFEicrlLjgPahqHcmkLQVsd5wFsaMMlKVRe79yIRWYkXRfILRqZqLOww6G4UMjMz2bdvH1OnTuWb3/wmqkqvXr2Ii4vj6quv5vHHH2fy5MnMnDmT3Nzc4+dVN/O6urSRI0eycOFC7rjjDgDmz59Peno6x44dY/78+cfzbdmyJQJ3F70UHS5k1WsPM2jLi0ziECvaTCLxtFsYdPIZSBMdCBcNBKvJdAT6BrwDlAOPhqHc46FmXZPpFOApF462TFUDl6ibBhx37IrIVLwm2zsuKR3YHAZNvrB06VI++ugjSkpKuO+++wDYunUrN954I2+88QaZmZnk5eVx+umnA3DVVVdx7733cvrpp7Ns2bLjM6xnzZrFtm3bjs/kPnbsGLNnz2b16tUsXbqUzz77jMLCQl599VUuvPBCFi9ezJ133kmLFi0oLi5mxowZ9O7dm7lz53L11VfTu3dvVJXZs2czduxYkpIatqZINKMVFSx94yn6rXyQSRSyKvFk8s+6i9FjJvstLSaodRa2iGSo6qaA/XgXLrZhBYucCXwbyAdKVfUeEXkQ2KeqM1yeUcClqnprwHnDgbuBZUBPvN6l+2srz2Zhh49Yem7Z65ZS9NpPGFL6GetbDIaz72PQyWf4LatJEY5Z2FkiMpoTzZvvA9c0VJiqvge8VyXttir7K4GVVdIygQsbWr7RvCk+eoSVf76dsbl/4Yi0ZvHwexj3rRsbvOyB8UVCMTJvAgoUuP3hkZNjGJEnJyuTY69czsTyLSzu+DUyLn2E8V16+C0rZgnFyOxR1Ssrd1ytJmZQVVuqsg409UXOlv/7eQZ8+nPKJY6VX36W8VMv9ltSzBOKy3yhiAT2ZY6MlJjGJjExkb179zb5H05joars3buXxMREv6XUmdKSYyz+/TTGLL6FnQmpHL3yQ0aZgWkUgnVh7wMOAAL8TETUbbcDZjWKugiTkpJCTk4O+fn5fktpMiQmJpKSEs1LJH2R4qLDbHjyAsYf/ZRF3b7H2KseJ8HmFDUawZpLN6jqX6omisi3I6inUUlISKBv31haucOoypGD+8n+/XkMP7aaT4f+gonfubX2k4ywUmNzqdLAiMg3qqT/M9KiDCMcFO7LJ+eJsxl4LJNlY2cwwQyML4Ti+H1SRG4J2Fe8wW/3qOrOyMgyjIZRuC+fvX84g75lOaw+5UlOPquhM2GM+hKKkXkFb+LiFqAfMB74GLgL+GHkpBlG/SgvK2P7zO8ysGwH67/yPGNOO99vSc2aUHqXjqrqf1V1q1uKoaWqLgA21XaiYfjBkuduYnjxMlaM+CUjzMD4Tig1mZNFZBzeBMQBwHgRaQkMjKgyw6gHy958hol5L7Mo+XwmXvh/fssxCM3I/Ap4FhgKrAF+BAwBPoygLsOoM1krFzB06S9Y23I4Y6551m85hqNWI6OqK/D8MACISA+3aNXKms8yjMbl4N49nPT6NA5Ie7pd/VdatrJxMNFCsMF4I1V1VcBaMpWcC1wUWVmGUTfWvXI743Qvm8+bw4BuTWuwYKwTzPF7k3u/Am89mcpXp0iLMoy6sHn1Qsbl/4slXS5gwJjT/JZjVKHGmoyqXuU2b3LLKwAgIkNrOMUwGh2tqKDkzZ9yUJIYfEl0LaFqeNRYkxGRPiLSByis3Hb7TTb8iBF7LH97JoNL17Bx+C2071TtMtGGzwRz/H6AF9uo6joIfYCfRUiPYYTM4YP76bNsBpviMxj3rZtqP8HwhdomSP67aqKIfC2CegwjZD575edMZD97v/oC8baiXdQSbILkFwxMsHTDaExyszcwJvevLO7wNQaN+4rfcowgWJwHo0mybd7viaeCPhfe67cUoxaCjZO5HXg4HJEJarh+bbGwpwHXAcUu6XlVne2OXQaMxgvRsllVbXhnM6LkWDEDc19nddtJjO6d7rccoxaC+WQS
VLVcRK5T1WcqE0XkNFVt0JSCEGNhA1wcEP+68twUYDowWlVVRJaIyPzAsC1GbJP5n5cYy0FyTr6q9syG7wQzMgNdbeIsESkKSD+Xhs9bqjUWtuMGEckD2gC/V9V9eCFrlwVElFwIfBWbFd5saL1qFrnSjWGTbYZ1UyCYT2YGXvC0ygiS4RzxG0os7A+BB1T1YWAp8I86nAt4sbBFZKmILLV1fGODbetXMKQkk219v2MxkpoIwUb8rgHWiMibERjxW2ssbFXdGrA7H3hDROJdvvQq534hDra7RpOMhW3UzK75T9FD4xlw9nV+SzFCJJTepS0icp+IvCki9+IN0Gsox2Nhu/1TgLdFpJOItAMQkftFpNIIZgBbnRP6XWCsnAiWNAmYGwZNRpRTXHSIIXveJrPdaXS2SZBNhlDWk3kEb03fF/F+7I8C1zakUFUtEpEfAU+ISD6wWlX/WxkLG6+plgc8LSJb8aJWft+dmyMiDwOPikg58Jw5fZsHq9+dxXiOkPilBn39jEYmFCOzWVUfrNwRkZ+Ho+DaYmGr6uNBzn0JeCkcOoymQ4c1f2ZbXG+GTDjbbylGHQiludTb+UJwzZdekZVkGF9k2/rlDCjbyK70i5E4G0PalAjl03oPyBaRlXgRC8z/YTQ6u5Z7s1lST/mOz0qMuhLK8ptviMhHeD06Wap6IPKyDOPzJOZ8TI70ICV1gN9SjDoSUr1TVQ+o6lIzMIYflJWW0P/ISnZ2HF97ZiPqsMatEfVkrVpAkhylRfoUv6UY9cCMjBH17P/M64TsO+4cn5UY9aHORkZEbKil0ai0y/2EzfH96NS1p99SjHoQbI3ffSKyRUR2iMgREcl2EyV/1Yj6jGZOcdEhMo6tYU/yBL+lGPUkWE3mBlXtBzwIdFXVNLyJiH9oDGGGAZC19L+0lDLaDJzqtxSjngRbfvMvbrObqh5xaYfxJiQaRqNweP1/KdV40k8+028pRj0JZVrBYBG5FW+9lgF485cMo1HovGchWS0HMTipg99SjHoSiuP3GqCLe08Gro6oIsNwFO7Pp39pFge6f8lvKUYDCGXE7z4RuQPoDOxV1YrIyzIM2LLkHUaL0n6I+WOaMrXWZETkLLw5S88D3xORH0ZclWEAJRvnU6StSB9zut9SjAYQSnPpXGAQ8D9VfRnoH1lJhuHRfd9iNrUeQctWiX5LMRpAKEYmR1WLgcrlK23+khFxCnZtI7Uih6Jep/gtxWggoRiZAc4nM0REbsDWkzEagdz1iwFonz7RZyVGQwnFyNyMNzYmGegO3B5RRYYBFO1cA0DPjNE+KzEaSii9S4dE5Bd4oVD2We+S0RjEFWyggA4kJ3f3W4rRQOrSu/QC1rtkNBIdDmeR1yrNbxlGGAhlxG9l79JPVPVlF1GgwYQQC/t2vOZZHjAWuEtV17tj2ZwIzbJTVS8NhyYjOtCKCnqVbiez4zf8lmKEgVCMTI6qFotI2HqXQoyFfRJwi4t3/V3gITyDBzBLVe9uqA4jOtmdk0V3KUa6DPZbihEG/OpdqikW9nFU9ZcB8a7jgMMBhyeLyG0icq+I1Djm3MLUNk12b14FQFKfYT4rMcKBX71LdYln3RL4AfCLgOQ7XCyo+4EXRCS9unNVdaaqjlPVcV26dAmDbKMxOFrZs5RuPUuxQKi9Sw/hjfTd7JZ7aCi1xsKG4wbmaeDnqro5QNNi917kQrWcQg3xsI2mR3zBetez1M1vKUYYCKV36fvAGrzepbUicnkYyg0lFnZr4FngEVVdJiIXuvSpIhK42Gs6XhhdI0Zof3iz9SzFEKE4fr8JpKlqiYgkAn8B/tyQQkOMhf0yMAzoKyIAbYFX8Wo8d4vIGKAn8KqqftwQPUb0oBUVpJRuI7PjubVnNpoEoRiZFapaAuB6mZYAiEgvVd1Z34JDiIV9QQ3nZQIX1rdcI7rxepaOgfUsxQyhGJlhIvJrvAF5/YCursl0LnBRJMUZzY/dm1fSHUjqM9xvKUaY
CKV3qQdQDqS6911AX7xpBoYRVo7mfAZAr4xRPisxwkUoNZmbXBPlc4jI0AjoMZo58Xs3ej1Lna1nKVYIFndpqIh8SVUzRaStiDwoIo+JSFcAVV3TeDKN5oL1LMUewZpLD+N1DwP8BhgBbAcejbQoo3miFeWklG7jcDsLiBFLBGsuLVPVyq7q7wJjVTVXRO5vBF1GMyRvRxY95BjS1XqWYolgNZkyABGZiDfTOdelF0dcldEsyd+8EoB2NmcppghWk+nvJkReBrwIICIpeAPkDCPs2Gp4sUmwmsyteN3W/wCecmn3Av+OtCijeRK/dwP5dKR9p2rnyhpNlGA1mUl4ExNLKhNU9YrISzKaKx0Ob2Z3q1RsvnxsEawm0x+YKyIvisjZIhLKwD3DqBcV5eX0Kt1uPUsxSI2GQ1UfVtWpeN3X44EPROQpEZncaOqMZsPunCzaWM9STFJr7URVs1T1XlWdjLdk5ldF5AeRl2Y0J/ZuWwfASb3MyMQaoUwrQETOxVtYaiVwr6oWRVSV0ew4krcRgOTUQT4rMcJNKItWPYgXVWAy0BJvrRfDCCu6dwtHtSVduqf6LcUIM6E4cw+4XqUtqroCb1EpwwgriYe2kRffg7j4eL+lGGEmFCOT7N4rIwck1ZTRMOpLh+IdHEhM8VuGEQFCMTIbRWQtcLmILAbWRliT0cyoKC+nR3kex5L6+C3FiAChRCt4RkQ+BIYCmaq6IfKyjOZE/q5sukkp0rm/31KMCBBS75KqrgPWgRcwTVVnNrTgEMLUJuItN7ETyABmqOpGd+wyYDTeSn2bVfXZhuox/KNg23q6AW26D/BbihEBajQyIrIPLyStcMIfI3hd2Q0yMiGGqb0Z2K6qD4rIcOB54MtukuZ0YLQLYbtEROar6qaGaDL8o7L7unMf676ORYL5ZG5Q1X6q2te991PVvsCNYSi31jC1bn8hHI9QMNLFZDobb62bSsO3EPhqGDQZPlFesIUSjadbijWXYpFg0wr+AiAi40TkZLc9EXgjDOWGEqa2pjx1CXFrsbCbAK0OZpMX1534FiG13o0mRii9S7cCR9z2EeC+MJQbSpjamvKEFOIWLBZ2U6F98Q72Wfd1zBKKkVmiqmvheLNlVxjKrTVMLTIwg9QAABX9SURBVPA2XrMK55NZpaoHgXeBseLCSro8c8OgyfABraigR1kuxSfZSN9YJZT66QARSVbVAhFJxuvpaRAhhql9HHhYRH6Bt6D5Ve7cHBF5GHhURMqB58zp23TZuyeHZDkGnfr5LcWIEKEYmVnAChFJAgqBi8NRcAhhao8C19dw7kvAS+HQYfhLwbZ1JAOtu9s6MrFKKIPxPgF6B9RmWjeCLqOZcGiX133dqbd1X8cqwcbJVBqVyQFpAN8HrmkEbUYzoLxgM2UaR7feVpOJVYLVZGbjjT95HG8dmUosEroRNhIKs8mL60pKq1a1ZzaaJDUaGVWtHOB2k6ouAHDr/E5qDGFG86Dd0R3sa5WCdWDHLqF0YQeuhzgEuDpCWoxmhlZU0K0sl6KTbPZ1LBPMJ9MO6AAMEpHKb8Fh4FhN5xhGXSjct4cOFEHHvn5LMSJIMJ/M+cA0IA0YhTc5sgx4J+KqjGbB7uy1dAASu5nTN5YJ5pP5E/AnEbkI+JeqljWeLKM5cCjX677umDLQZyVGJAnFJ/MkYIMYjLBTWrCZChW6pZqRiWVCMTIvqepnlTsiMiaCeoxmRMKBreyRZBJbt/VbihFBQplW0EFEZuCtjKfAucBFEVVlNAuSinZQ0LIX3f0WYkSUUGoyY4GjeA7gvkCnSAoymg9dy3I5Yt3XMU8oNZkfqeqiyh0Rsa4Ao8EcPFBARw5SYd3XMU8osbAXiUhbEenjxstc2gi6jBhnV9YqAFp1s8XDY51aazIicgvepMgkYDfQE7g7srKMWKdw63IAug842WclRqQJxSfTXVVHA39U1VOA30dYk9EcyMvkIG3p0cda37FOKEbmsHuvXFfX
BjUYDaZ94Xp2tOyPxIXyFTSaMqF8wikici6wQ0Q2Az0irMmIccrLyuhTupVDHQbXntlo8oSyMt61ldsishDYGFFFRsyTs/kzUqWEuJ4j/ZZiNAK11mREpIOIPCAibwGXAW0iL8uIZfKzlgDQuf84n5UYjUEozaXngb3AC3iRBJ5vSIEu7MlMEblDRJ4XkW7V5DlZRF4Wkeki8kcRuSbg2DMi8kHAy1bqa2KU5qymRFvQe4DVZJoDoQzG26CqD1buiMgDDSzzt8B/VPXvztfzMF4XeSA9gMdVdbGIJAB7RORfqloA5KnqdQ3UYPjISfvXsL1FKumtEv2WYjQCodRkDlVGKBCRNkCu2/5ePcs8HuOa6mNgo6pvqOrigKQyoNRtJ4nIz0XkdhG5QUQstmlTQpVexZvYl2SdlM2FUH6gPwF+4YKwdQX2icjNeOFhX6nuBBF5F/hCMwi4i8/Hsj4IdBSRFkHWq7kB+K2qFrr9l/GCwZW5YHB3AvfWoONa4FqAPn1sjkw0UJC3nWQOUtFtmN9SjEYiFCPzG1V9smqiiwBZLap6dk3HRKQylvUBPEO1vyYDIyKXAG1V9Xj8bVVdHpBlPnA7NRgZVZ0JzAQYN26c1qTJaDxy1y8mGWiXZiuGNBdCmbv0BQPj0p+uZ5nHY1zjYmCDFwkhYC1hRORqoKuq3iciw0VkgEt/KOBaGUBWPXUYPnBk+woAUgaP91mJ0Vj44c/4GfCAMxr9gekufQRerKfhInIe8Du88LjfAjoDN+KN0eni1rcpwht9fEsj6zcaQMv8NeyUbvTq0NlvKUYjESxawWnAR6oa1maGqu6jmgiUqroSFzhOVecA7Ws4f1o49RiNS9cjG9ndZgC9/BZiNBrBmksXq6qKyDmBiSKSFlFFRsxy5NABelXs4ljnoX5LMRqRYEam1BmUsyrXknE+kx83ijIj5tixbilxorTuM8pvKUYjEswnsxCvZyYDGB2Q3ge4LZKijNikMHsZAN0H2hoyzYlgcZdeAV4RkXNV9c3KdBH5WqMoM2IOyctkP0l069XPbylGIxLKLOw3ReQMYCSwUlX/HXlZRizS4eAGdrbqT0dbQ6ZZEcos7F/idROnAtPdvmHUibLSEltDppkSyjiZlqp6vIkkIvdHUI8Ro2SvXUK6lNKilzl9mxuh1FvLq+xXREKIEdsULH+DChX6jv/CfFgjxgmlJlMmIm8AW/BG6H4aWUlGLNJp53w2JQxgYPfefksxGplQHL/3ichZeMP+31bV9yIvy4glCvK2M6BsIwtTbRmg5khIc5dUdR4wL8JajBhl6yf/IhnoMu5bfksxfMD6Eo2I02Lzu+SRTP+hE/yWYviAGRkjohQfPcLAw0vZ1vlUi7HUTAllnMxvG0OIEZts/HQubeQYiUO/4bcUwydC+WsZLiJ/EJEfi0jbiCsyYoqjn71NkbZi4CSbjdJcCcXx+x1VPSoiQ4A/iMhe4PequjXC2owmjlZUkFqwgA1txzK6tf0/NVdCqclMFZGxeAt2TwJ2AReKyM8iqsxo8mSvW0J38intf5bfUgwfCaUm8xKwAngS+IGqVgCIyOORFGY0fXYvmUNfoN+kC/yWYvhIKEbml5WLiYtIGxGpwJtqsDKiyowmT8ec/7KpRQYZPVP9lmL4SCjNpdYB2z2A51W1VFVfjJAmIwYoyM0mo3QDBT1P91uK4TPBFhLvA6QBg0RkskuOAxq0sLiIdAJm4M2FygB+pqq7q8mXDWS73Z2qeqlLTwN+iRcKJQ34qaoebogmI/xsevtRJgApk3/gtxTDZ4I1l0YD3wJGAeLSyoG3GlhmKLGwAWap6t3VpD8D3OXiZN+IF9zN1riJIo4cOsCQnf9g5UmnMibdIkU2d4ItvzkHmCMiJ6vqkjCW+XXgN277f8Cfasg3WURuw4s2OVdVPxGRBOB0YEnA+c9Rg5GxMLX+8Nlbv2cCR2hz+v/5LcWIAoI1l8TFXNodGNkRuE5Vg3ZfhykW
9h2uttIGWC4i3wCOAEcDYkEddNerFgtT2/iUlZbQZ+Ms1iUMZfC4qX7LMaKAYM2lT4HxwIfAVk40mfrgRYGskXDEwlbVxe69SERW4oW0/QvQOsAAtgP2BNNiNC6r3v0TYzWfvPH3+C3FiBKCNZcqgxXfFOZoBZWxsHdQJRY2kKKq20VkKpCgqu+4c9KBzapaKiLvAycDiwPPN/xHKypov+JptsWlMPIrF/stx4gSQhkns6vSLyMiE4GPGlhmrbGw8Wond4vIGKAn8KqqfuzyXQfc5RbS6oPFwo4a1nzyFsPKN/Pp8LtJjY/3W44RJYRiZKYDv3bbR4D7gJvrW2CIsbAzgQtrOD8buLK+5RuRo+LjJyigAyO/dq3fUowoIpTBeEtVdS0c//Hviqwkoymy/tN5jChewqa075FokyGNAEIxMgNEJBnAvadHVpLR1Cg+eoQ2795MrnRlxIV3+C3HiDJCMTKzgBUiUggsA16IqCKjybFi9p30qdhJwZQHaZvUwW85RpQRSrSCT4DeIpKsqgWNoMloQmSt+h8n75zN4o5fY/xp5/stx4hCQll+s7uIvAy8LyKzRaS6QXZGM6S05Bjyxg0ckHYMvPwJv+UYUUoozaXfAHOAH+DNW5oRUUVGk2HpK/fQv3wL2yf9hvaduvgtx4hSQunCXq+qf3fby0WkXyQFGU2DDUvnM2bLTJYlncbYsy/zW44RxYRSk0l3yzNU9i6ZkWnm7Nyyhq5v/YCCuE70vfxpv+UYUU4oNZk/AatEJAkoBGy8eDNmf/4uKmZ/G1AqLvkHnbr28luSEeXUuXdJRFrXdo4RmxQXHWb3s+eTVpFP9tdfYVDGSL8lGU2AYEs9TK4mDbwFpr4wLcCIbcpKS1j3h4sZWbqelZMeY8z4M/2WZDQRgtVkHgVWcWKJh0qGR06OEY0cOXSArKcuYvTRxSwaOJ2J50zzW5LRhAhmZG5S1f9VTRSRUyKox4gyCvJ2sP+5bzGsdDOfDruLiRf91G9JRhMj2Hoy/wMQkZbAD4EEvIWsNjWONMNvdmxcSfwr36FXxQEyJz/DhKnm8zfqTihd2I8CnfDWbtnFiWUfjBhmxbyXSPrL10nUo+Sc9w9GmYEx6kkoRiZbVe8BdqnqFmBnhDUZPlJ0uJDFT1zG6E+upyC+G0cvf4cBY07zW5bRhAllnEw/EWkFqFsi0+YuxSgbl39I6zevY1zFLhb2vJyx0x6iZatEv2UZTZxgXdgjVHU1MA9vIXHFCy9icS5ijMK9u1n/yp2Mzf8XBdKJdWe9zKRTvu63LCNGCFaTeVpEfqaq/3KLd6cDWap6oJG0GRGmrLSEZa89ysB1TzBOj7C0y/kMuuRBhtpkRyOMBDMys4GeIvIssAH4UzgMTChhakVkCvAHIN8ldQX+rqp3i8gzwKCA7De6ZUGNEKkoL2flvD/TccmjTKjYxpqWI2lz3kNMGDrBb2lGDBKsC/sZt/mKiGQA/yci7YDXVPWDBpQZSpjaXOAyVV0BICLPAy+6Y3mqel0Dym+2lJWWsHLu83RZ+QfGVOxgh/Rk+cTHGX3W5UhcKH0AhlF3QnH8ghf4fg1wA3AZXpd2fak1TK2qbqzcdotktVLVbS4pSUR+DpThRU94prrgcMYJDu7dw9p3nqZ31suM091sjUtl2cm/Y9TZ0+jdItSvgGHUj2CO37PwHL4/BC4HsoA/An+t7aJhClNbyY+BZwL2XwZWq2qZiDwI3AncW4OOZhsLWysq2LRyAQc+eoYR+99jopSyPmEIy8f9klFnXEJfi4tkNBJyIqx0lQMi+XjjaF4G/hguv4eI7AC+pKo7nH8mS1WrrRm5rvPXVLXarg4ROQe4XVVPr63ccePG6dKlSxsivUmQu3U92z6cRc/tb5JakUORtiIz+RySp/yY/sMn+i3PiEFEZJmqjqvpeLC68lzgWlUtDrOmWsPUBuS9BHgl8GQReUhVb3W7GXg1rGbNzi1r2LHwn7TPfofBpWvp
CaxNGMang69g8JlXMKFDZ78lGs2YYEbmRxEwMBBamNpKLgLOq3J+FxGZARQBA2mGYWrLSkvIWv4B+zPn0j13Pn0rsukFbIlLY2Hf60mb8gOGpA70W6ZhAEGaS7FGU24uaUUF2zeuJG/1f2i57UPSDy8jSY5SrsL6VsM4lHY2fSZdRM++g2q/mGGEmYY0lwyfKC05Rvbaxexb/zEJOQtJPbySVApJBfLowrrOZ9AiYyr9J3zDBs4ZUY8ZGZ/Rigp2blnL7g2LKN2xnHZ7V9G3ZCMZUgJAHslsbTeezamn0GPkGaT0G0p3G9NiNCHMyDQixUePsGPDcvZvXUHFrkySDqynd0kWKRSRApRoC7Ym9GdVt/NJSB1Pz2FfpnvvDDMqRpPGjEwEOFS4j12bMzmYs5bS3etJ3L+R5KPZ9KzYRYZ4PrCj2pIdCWms63wm0mMUnTIm0GfQWAbarGcjxjAjU08K9+WTv2MDhTs3UZK/ifgD2bQ9sp1uJTtI5gBJLl+pxrMzvif5bdPJ6fg1WvUcRpeMcfRMG8wAG21rNAPsW14NFeXl7C/Yxb5dWzm0O5uSvdlwYAetjuykXXEuXcrzaE8R7QPOKaADBQk92dLhS2zq2I9W3QfRuc8QevQbSlqrRNJ8uhfD8BszMo6iw4VkP/F12pfuoUvFXjpLGYFD2Iq0FfnxXTnQqgcFbUehHVJp1aUv7XoOpFvqQJLbdSTZN/WGEb2YkXG0bpNEhcSz66RhbD+pB9KuF6069yGpWxrJvdJp36krqXFxpPot1DCaGGZkHBIXx7A7P/RbhmHEHNY3ahhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRDEjYxhGRGl0IyMicSLyQxHZIyLDguQ7Q0SeEpG7ReRXAemdRGSmiNwhIs+7CJOGYUQpfkyQHAl8ihfSpFpEpA1e1MihqnpMRF4Vkamq+l9Ci6VtGEaU0Og1GVVdoaora8k2Cdimqsfc/v/wYmjj3hdWk24YRhQSkZpMsFjYqvpGCJcIjJcNXszsrtUcCxpLOzAWNnBYRDaEUHYyUBBCPr8wfQ0jmvVFszaoWV/QZZYiYmRU9ewGXmIPHF8mF6CdSws8dsCl76/OwDgdM4GZdSlYRJYGC1TlN6avYUSzvmjWBvXXF1W9SyLS120uBFJFpJXbPx4zmxOxtKumG4YRhTS641dEOgLXA+2Ba0XkL6q6SES6AB+LSH9VLRKRHwFPiEg+sNo5faHmWNqGYUQhjW5kVHU/cJ97BabnA70C9t8D3qvm/H3ANRGUWKfmlQ+YvoYRzfqiWRvUU5+oariFGIZhHCeqfDKGYcQeZmQMw4goFhLFISJnABfgdZGrqt7jsyREpDue72qkqp7s0hLxRjnvBDKAGaq60Qdt/Z225UAKsFdVfy0inYAZwBan72equtsHfXHAm3ijy1vidRJcCbSOBn1OY2unb56qTo+WzzZA3yKg2O2Wq+rUen2+qtrsX0AbIAto5fZfBaZGga5vA+cCSwPS7gBuc9vDgQU+aTsZOC9gfy0wFm86yHdc2rnAbJ/0xQG/CNifA1waLfpc+b8D/gQ8HE2fbYC+u6tJq/Pzs+aSR7BpDL6hqv/k8yOfIWBahapmAiNFpJ0P2pao6pyApDjgCFEy7UNVK1T1PgARaYFX29oQLfpE5Puu/K0ByVHx2QYwXERud5OU6z2tx5pLHsGmMUQbNWk96I8cEJHzgXdVdb2IhDzto5G0nQ38H/CWqi6NBn0iMgQYrKo/E5ERAYei7bN9QFUXi0g88JGIHKIO03oqsZqMR7BpDNFGVGkVkdOB0/F+yPB5fUGnfTQGqvquqp4D9BWRH0eJvvOBYhG5AzgVGC8iNxNln62qLnbv5cACvM+5zs/PajIex6cxuCbTKcBTPmuqicppFQtEZDiwSlV9+adzVegvAz8BeohIaoC+Hfg47cPVFvqqamX5
W4F+0aBPVX8ToDMROElVH3Pb0fLZDgJOUdXnXVIG8Br1eH42GM8hImfiOVrzgVKNjt6l04DLgXOAp/EcheD1QOwC0oHfqj+9S2OBD4GlLqkt8AfgDeABYBtej84d6k/vUn/gIbzerwRgMHATUBIN+pzGC/Gm2LTEe3avEwWfrdPW02lajldjSQBuATpQx+dnRsYwjIhiPhnDMCKKGRnDMCKKGRnDMCKKGRnDMCKKGRnDMCKKjZMxIo6ILMCbCNgZbxLqH92hXng9nBf7pc2IPNaFbUQcEblCVV90wfzeUtW0ynRgltqXMKax5pIRcVT1xRoOJeEmCIrIFSKSJyK3ishsEZkrIt9xUUI/qpwoKCJDReTPLt/zItKvse7DqB9mZAzfUNUnArZfBNYDy1X1+8AxIElVrwJWAGe6rM8Bz6jqQ8BsToyCNqIU88kY0cZm934gYHs/JybljQDOEpHJeAtQHW5ceUZdMSNjNDVWAa+p6moXl+t8vwUZwTEjYzQKbqnJa4H2InKlqr7gll5oLyLfwwt/mgpME5E38Gos3xeRXGAy3gJKc4GrgJ+KyFagN/CSH/djhI71LhmGEVHM8WsYRkQxI2MYRkQxI2MYRkQxI2MYRkQxI2MYRkQxI2MYRkQxI2MYRkT5f9AeHAWEZceDAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare the learned model's velocity rollout against ground truth, one generated sample per panel.\n",
    "N_ROWS, N_COLS = 1, 1  # panel grid size\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D for any grid size, so axes[(row, col)] always works.\n",
    "fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(4*N_COLS, 4*N_ROWS),\n",
    "                         sharex=True, sharey=True, squeeze=False)\n",
    "ax_idx = [(row, col) for row in range(N_ROWS) for col in range(N_COLS)]\n",
    "\n",
    "# The learned vector field depends only on the trained params, not on the sample,\n",
    "# so build and JIT-compile it once instead of re-tracing it on every loop iteration.\n",
    "runner = jax.jit(jax.vmap(\n",
    "    partial(\n",
    "        hamiltonian_eom,\n",
    "        learned_dynamics(params, nn_forward_fn)), (0, 0), 0))\n",
    "\n",
    "@jax.jit\n",
    "def odefunc_learned(y, t):\n",
    "    # y = [q, q_t, conditional]; the conditional is carried along with zero time derivative.\n",
    "    return jnp.concatenate((runner(y[None, :2], y[None, [2]])[0], jnp.zeros(1)))\n",
    "\n",
    "for i, ci in enumerate(tqdm(ax_idx)):\n",
    "    # NOTE(review): gen_data is defined elsewhere in this notebook; (i+4)*(i+1) presumably acts\n",
    "    # as a per-panel seed and 50 as the trajectory length -- confirm against its definition.\n",
    "    cstate, cconditionals, _ = gen_data((i+4)*(i+1), 1, 50)\n",
    "\n",
    "    yt_learned = odeint(\n",
    "        odefunc_learned,\n",
    "        jnp.concatenate([cstate[0], cconditionals[0]]),\n",
    "        jnp.linspace(0, 1, 50),\n",
    "        mxsteps=100)\n",
    "\n",
    "    cax = axes[ci]\n",
    "    cax.plot(cstate[:, 1], label='Truth')\n",
    "    cax.plot(yt_learned[:, 1], label='Learned')\n",
    "    cax.legend()\n",
    "    if ci[1] == 0:  # y label only on the left column (shared y axis)\n",
    "        cax.set_ylabel('Velocity of particle/Speed of light')\n",
    "    if ci[0] == N_ROWS - 1:  # x label only on the bottom row (shared x axis)\n",
    "        cax.set_xlabel('Time')\n",
    "    cax.set_ylim(-1, 1)\n",
    "\n",
    "# Title an explicit Axes rather than relying on pyplot's implicit 'current axes'.\n",
    "axes[0, 0].set_title('Hamiltonian NN - Special Relativity')\n",
    "plt.tight_layout()\n",
    "plt.savefig('sr_hnn_true.png', dpi=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main2",
   "language": "python",
   "name": "main2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
